cognee 0.5.0__py3-none-any.whl → 0.5.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. cognee/api/client.py +5 -1
  2. cognee/api/v1/add/add.py +1 -2
  3. cognee/api/v1/cognify/code_graph_pipeline.py +119 -0
  4. cognee/api/v1/cognify/cognify.py +16 -24
  5. cognee/api/v1/cognify/routers/__init__.py +1 -0
  6. cognee/api/v1/cognify/routers/get_code_pipeline_router.py +90 -0
  7. cognee/api/v1/cognify/routers/get_cognify_router.py +1 -3
  8. cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
  9. cognee/api/v1/ontologies/ontologies.py +37 -12
  10. cognee/api/v1/ontologies/routers/get_ontology_router.py +25 -27
  11. cognee/api/v1/search/search.py +0 -4
  12. cognee/api/v1/ui/ui.py +68 -38
  13. cognee/context_global_variables.py +16 -61
  14. cognee/eval_framework/answer_generation/answer_generation_executor.py +0 -10
  15. cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
  16. cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +2 -0
  17. cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
  18. cognee/eval_framework/eval_config.py +2 -2
  19. cognee/eval_framework/modal_run_eval.py +28 -16
  20. cognee/infrastructure/databases/graph/config.py +0 -3
  21. cognee/infrastructure/databases/graph/get_graph_engine.py +0 -1
  22. cognee/infrastructure/databases/graph/graph_db_interface.py +0 -15
  23. cognee/infrastructure/databases/graph/kuzu/adapter.py +0 -228
  24. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +1 -80
  25. cognee/infrastructure/databases/utils/__init__.py +0 -3
  26. cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +48 -62
  27. cognee/infrastructure/databases/vector/config.py +0 -2
  28. cognee/infrastructure/databases/vector/create_vector_engine.py +0 -1
  29. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +6 -8
  30. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +7 -9
  31. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +10 -11
  32. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +544 -0
  33. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +0 -2
  34. cognee/infrastructure/databases/vector/vector_db_interface.py +0 -35
  35. cognee/infrastructure/files/storage/s3_config.py +0 -2
  36. cognee/infrastructure/llm/LLMGateway.py +2 -5
  37. cognee/infrastructure/llm/config.py +0 -35
  38. cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
  39. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +8 -23
  40. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +16 -17
  41. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +37 -40
  42. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +36 -39
  43. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +1 -19
  44. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +9 -11
  45. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +21 -23
  46. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +34 -42
  47. cognee/modules/cognify/config.py +0 -2
  48. cognee/modules/data/deletion/prune_system.py +2 -52
  49. cognee/modules/data/methods/delete_dataset.py +0 -26
  50. cognee/modules/engine/models/__init__.py +0 -1
  51. cognee/modules/graph/cognee_graph/CogneeGraph.py +37 -85
  52. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +3 -8
  53. cognee/modules/memify/memify.py +7 -1
  54. cognee/modules/pipelines/operations/pipeline.py +2 -18
  55. cognee/modules/retrieval/__init__.py +1 -1
  56. cognee/modules/retrieval/code_retriever.py +232 -0
  57. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +0 -4
  58. cognee/modules/retrieval/graph_completion_cot_retriever.py +0 -4
  59. cognee/modules/retrieval/graph_completion_retriever.py +0 -10
  60. cognee/modules/retrieval/graph_summary_completion_retriever.py +0 -4
  61. cognee/modules/retrieval/temporal_retriever.py +0 -4
  62. cognee/modules/retrieval/utils/brute_force_triplet_search.py +10 -42
  63. cognee/modules/run_custom_pipeline/run_custom_pipeline.py +1 -8
  64. cognee/modules/search/methods/get_search_type_tools.py +8 -54
  65. cognee/modules/search/methods/no_access_control_search.py +0 -4
  66. cognee/modules/search/methods/search.py +0 -21
  67. cognee/modules/search/types/SearchType.py +1 -1
  68. cognee/modules/settings/get_settings.py +0 -19
  69. cognee/modules/users/methods/get_authenticated_user.py +2 -2
  70. cognee/modules/users/models/DatasetDatabase.py +3 -15
  71. cognee/shared/logging_utils.py +0 -4
  72. cognee/tasks/code/enrich_dependency_graph_checker.py +35 -0
  73. cognee/tasks/code/get_local_dependencies_checker.py +20 -0
  74. cognee/tasks/code/get_repo_dependency_graph_checker.py +35 -0
  75. cognee/tasks/documents/__init__.py +1 -0
  76. cognee/tasks/documents/check_permissions_on_dataset.py +26 -0
  77. cognee/tasks/graph/extract_graph_from_data.py +10 -9
  78. cognee/tasks/repo_processor/__init__.py +2 -0
  79. cognee/tasks/repo_processor/get_local_dependencies.py +335 -0
  80. cognee/tasks/repo_processor/get_non_code_files.py +158 -0
  81. cognee/tasks/repo_processor/get_repo_file_dependencies.py +243 -0
  82. cognee/tasks/storage/add_data_points.py +2 -142
  83. cognee/tests/test_cognee_server_start.py +4 -2
  84. cognee/tests/test_conversation_history.py +1 -23
  85. cognee/tests/test_delete_bmw_example.py +60 -0
  86. cognee/tests/test_search_db.py +1 -37
  87. cognee/tests/unit/api/test_ontology_endpoint.py +89 -77
  88. cognee/tests/unit/infrastructure/mock_embedding_engine.py +7 -3
  89. cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +5 -0
  90. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
  91. cognee/tests/unit/modules/graph/cognee_graph_test.py +0 -406
  92. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/METADATA +89 -76
  93. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/RECORD +97 -118
  94. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/WHEEL +1 -1
  95. cognee/api/v1/ui/node_setup.py +0 -360
  96. cognee/api/v1/ui/npm_utils.py +0 -50
  97. cognee/eval_framework/Dockerfile +0 -29
  98. cognee/infrastructure/databases/dataset_database_handler/__init__.py +0 -3
  99. cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +0 -80
  100. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +0 -18
  101. cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +0 -10
  102. cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +0 -81
  103. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +0 -168
  104. cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +0 -10
  105. cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +0 -10
  106. cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +0 -30
  107. cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +0 -50
  108. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +0 -5
  109. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +0 -153
  110. cognee/memify_pipelines/create_triplet_embeddings.py +0 -53
  111. cognee/modules/engine/models/Triplet.py +0 -9
  112. cognee/modules/retrieval/register_retriever.py +0 -10
  113. cognee/modules/retrieval/registered_community_retrievers.py +0 -1
  114. cognee/modules/retrieval/triplet_retriever.py +0 -182
  115. cognee/shared/rate_limiting.py +0 -30
  116. cognee/tasks/memify/get_triplet_datapoints.py +0 -289
  117. cognee/tests/integration/retrieval/test_triplet_retriever.py +0 -84
  118. cognee/tests/integration/tasks/test_add_data_points.py +0 -139
  119. cognee/tests/integration/tasks/test_get_triplet_datapoints.py +0 -69
  120. cognee/tests/test_dataset_database_handler.py +0 -137
  121. cognee/tests/test_dataset_delete.py +0 -76
  122. cognee/tests/test_edge_centered_payload.py +0 -170
  123. cognee/tests/test_pipeline_cache.py +0 -164
  124. cognee/tests/unit/infrastructure/llm/test_llm_config.py +0 -46
  125. cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +0 -214
  126. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +0 -608
  127. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +0 -83
  128. cognee/tests/unit/tasks/storage/test_add_data_points.py +0 -288
  129. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/entry_points.txt +0 -0
  130. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/LICENSE +0 -0
  131. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,164 +0,0 @@
1
- """
2
- Test suite for the pipeline_cache feature in Cognee pipelines.
3
-
4
- This module tests the behavior of the `pipeline_cache` parameter which controls
5
- whether a pipeline should skip re-execution when it has already been completed
6
- for the same dataset.
7
-
8
- Architecture Overview:
9
- ---------------------
10
- The pipeline_cache mechanism works at the dataset level:
11
- 1. When a pipeline runs, it logs its status (INITIATED -> STARTED -> COMPLETED)
12
- 2. Before each run, `check_pipeline_run_qualification()` checks the pipeline status
13
- 3. If `use_pipeline_cache=True` and status is COMPLETED/STARTED, the pipeline skips
14
- 4. If `use_pipeline_cache=False`, the pipeline always re-executes regardless of status
15
- """
16
-
17
- import pytest
18
-
19
- import cognee
20
- from cognee.modules.pipelines.tasks.task import Task
21
- from cognee.modules.pipelines import run_pipeline
22
- from cognee.modules.users.methods import get_default_user
23
-
24
- from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
25
- reset_dataset_pipeline_run_status,
26
- )
27
- from cognee.infrastructure.databases.relational import create_db_and_tables
28
-
29
-
30
- class ExecutionCounter:
31
- """Helper class to track task execution counts."""
32
-
33
- def __init__(self):
34
- self.count = 0
35
-
36
-
37
- async def create_counting_task(data, counter: ExecutionCounter):
38
- """Create a task that increments a counter from the ExecutionCounter instance when executed."""
39
- counter.count += 1
40
- return counter
41
-
42
-
43
- class TestPipelineCache:
44
- """Tests for basic pipeline_cache on/off behavior."""
45
-
46
- @pytest.mark.asyncio
47
- async def test_pipeline_cache_off_allows_reexecution(self):
48
- """
49
- Test that with use_pipeline_cache=False, the pipeline re-executes
50
- even when it has already completed for the dataset.
51
-
52
- Expected behavior:
53
- - First run: Pipeline executes fully, task runs once
54
- - Second run: Pipeline executes again, task runs again (total: 2 times)
55
- """
56
- await cognee.prune.prune_data()
57
- await cognee.prune.prune_system(metadata=True)
58
- await create_db_and_tables()
59
-
60
- counter = ExecutionCounter()
61
- user = await get_default_user()
62
-
63
- tasks = [Task(create_counting_task, counter=counter)]
64
-
65
- # First run
66
- pipeline_results_1 = []
67
- async for result in run_pipeline(
68
- tasks=tasks,
69
- datasets="test_dataset_cache_off",
70
- data=["sample data"], # Data is necessary to trigger processing
71
- user=user,
72
- pipeline_name="test_cache_off_pipeline",
73
- use_pipeline_cache=False,
74
- ):
75
- pipeline_results_1.append(result)
76
-
77
- first_run_count = counter.count
78
- assert first_run_count >= 1, "Task should have executed at least once on first run"
79
-
80
- # Second run with pipeline_cache=False
81
- pipeline_results_2 = []
82
- async for result in run_pipeline(
83
- tasks=tasks,
84
- datasets="test_dataset_cache_off",
85
- data=["sample data"], # Data is necessary to trigger processing
86
- user=user,
87
- pipeline_name="test_cache_off_pipeline",
88
- use_pipeline_cache=False,
89
- ):
90
- pipeline_results_2.append(result)
91
-
92
- second_run_count = counter.count
93
- assert second_run_count > first_run_count, (
94
- f"With pipeline_cache=False, task should re-execute. "
95
- f"First run: {first_run_count}, After second run: {second_run_count}"
96
- )
97
-
98
- @pytest.mark.asyncio
99
- async def test_reset_pipeline_status_allows_reexecution_with_cache(self):
100
- """
101
- Test that resetting pipeline status allows re-execution even with
102
- pipeline_cache=True.
103
- """
104
- await cognee.prune.prune_data()
105
- await cognee.prune.prune_system(metadata=True)
106
- await create_db_and_tables()
107
-
108
- counter = ExecutionCounter()
109
- user = await get_default_user()
110
- dataset_name = "reset_status_test"
111
- pipeline_name = "test_reset_pipeline"
112
-
113
- tasks = [Task(create_counting_task, counter=counter)]
114
-
115
- # First run
116
- pipeline_result = []
117
- async for result in run_pipeline(
118
- tasks=tasks,
119
- datasets=dataset_name,
120
- user=user,
121
- data=["sample data"], # Data is necessary to trigger processing
122
- pipeline_name=pipeline_name,
123
- use_pipeline_cache=True,
124
- ):
125
- pipeline_result.append(result)
126
-
127
- first_run_count = counter.count
128
- assert first_run_count >= 1
129
-
130
- # Second run without reset - should skip
131
- async for _ in run_pipeline(
132
- tasks=tasks,
133
- datasets=dataset_name,
134
- user=user,
135
- data=["sample data"], # Data is necessary to trigger processing
136
- pipeline_name=pipeline_name,
137
- use_pipeline_cache=True,
138
- ):
139
- pass
140
-
141
- after_second_run = counter.count
142
- assert after_second_run == first_run_count, "Should have skipped due to cache"
143
-
144
- # Reset the pipeline status
145
- await reset_dataset_pipeline_run_status(
146
- pipeline_result[0].dataset_id, user, pipeline_names=[pipeline_name]
147
- )
148
-
149
- # Third run after reset - should execute
150
- async for _ in run_pipeline(
151
- tasks=tasks,
152
- datasets=dataset_name,
153
- user=user,
154
- data=["sample data"], # Data is necessary to trigger processing
155
- pipeline_name=pipeline_name,
156
- use_pipeline_cache=True,
157
- ):
158
- pass
159
-
160
- after_reset_run = counter.count
161
- assert after_reset_run > after_second_run, (
162
- f"After reset, pipeline should re-execute. "
163
- f"Before reset: {after_second_run}, After reset run: {after_reset_run}"
164
- )
@@ -1,46 +0,0 @@
1
- import pytest
2
-
3
- from cognee.infrastructure.llm.config import LLMConfig
4
-
5
-
6
- def test_strip_quotes_from_strings():
7
- """
8
- Test if the LLMConfig.strip_quotes_from_strings model validator behaves as expected.
9
- """
10
- config = LLMConfig(
11
- # Strings with surrounding double quotes ("value" → value)
12
- llm_api_key='"double_value"',
13
- # Strings with surrounding single quotes ('value' → value)
14
- llm_endpoint="'single_value'",
15
- # Strings without quotes (value → value)
16
- llm_api_version="no_quotes_value",
17
- # Empty quoted strings ("" → empty string)
18
- fallback_model='""',
19
- # None values (should remain None)
20
- baml_llm_api_key=None,
21
- # Mixed quotes ("value' → unchanged)
22
- fallback_endpoint="\"mixed_quote'",
23
- # Strings with internal quotes ("internal\"quotes" → internal"quotes")
24
- baml_llm_model='"internal"quotes"',
25
- )
26
-
27
- # Strings with surrounding double quotes ("value" → value)
28
- assert config.llm_api_key == "double_value"
29
-
30
- # Strings with surrounding single quotes ('value' → value)
31
- assert config.llm_endpoint == "single_value"
32
-
33
- # Strings without quotes (value → value)
34
- assert config.llm_api_version == "no_quotes_value"
35
-
36
- # Empty quoted strings ("" → empty string)
37
- assert config.fallback_model == ""
38
-
39
- # None values (should remain None)
40
- assert config.baml_llm_api_key is None
41
-
42
- # Mixed quotes ("value' → unchanged)
43
- assert config.fallback_endpoint == "\"mixed_quote'"
44
-
45
- # Strings with internal quotes ("internal\"quotes" → internal"quotes")
46
- assert config.baml_llm_model == 'internal"quotes'
@@ -1,214 +0,0 @@
1
- import sys
2
- import pytest
3
- from unittest.mock import AsyncMock, patch
4
-
5
- from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints
6
- from cognee.modules.engine.models import Triplet
7
- from cognee.modules.engine.models.Entity import Entity
8
- from cognee.infrastructure.engine import DataPoint
9
- from cognee.modules.graph.models.EdgeType import EdgeType
10
-
11
-
12
- get_triplet_datapoints_module = sys.modules["cognee.tasks.memify.get_triplet_datapoints"]
13
-
14
-
15
- @pytest.fixture
16
- def mock_graph_engine():
17
- """Create a mock graph engine with get_triplets_batch method."""
18
- engine = AsyncMock()
19
- engine.get_triplets_batch = AsyncMock()
20
- return engine
21
-
22
-
23
- @pytest.mark.asyncio
24
- async def test_get_triplet_datapoints_success(mock_graph_engine):
25
- """Test successful extraction of triplet datapoints."""
26
- mock_triplets_batch = [
27
- {
28
- "start_node": {
29
- "id": "node1",
30
- "type": "Entity",
31
- "name": "Alice",
32
- "description": "A person",
33
- },
34
- "end_node": {
35
- "id": "node2",
36
- "type": "Entity",
37
- "name": "Bob",
38
- "description": "Another person",
39
- },
40
- "relationship_properties": {
41
- "relationship_name": "knows",
42
- },
43
- }
44
- ]
45
-
46
- mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch
47
-
48
- with (
49
- patch.object(
50
- get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
51
- ),
52
- patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
53
- ):
54
- mock_get_subclasses.return_value = [Triplet, EdgeType, Entity]
55
-
56
- triplets = []
57
- async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
58
- triplets.append(triplet)
59
-
60
- assert len(triplets) == 1
61
- assert isinstance(triplets[0], Triplet)
62
- assert triplets[0].from_node_id == "node1"
63
- assert triplets[0].to_node_id == "node2"
64
- assert "Alice" in triplets[0].text
65
- assert "knows" in triplets[0].text
66
- assert "Bob" in triplets[0].text
67
-
68
-
69
- @pytest.mark.asyncio
70
- async def test_get_triplet_datapoints_edge_text_priority_and_fallback(mock_graph_engine):
71
- """Test that edge_text is prioritized over relationship_name, and fallback works."""
72
-
73
- class MockEntity(DataPoint):
74
- name: str
75
- metadata: dict = {"index_fields": ["name"]}
76
-
77
- mock_triplets_batch = [
78
- {
79
- "start_node": {"id": "node1", "type": "Entity", "name": "Alice"},
80
- "end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
81
- "relationship_properties": {
82
- "relationship_name": "knows",
83
- "edge_text": "has a close friendship with",
84
- },
85
- },
86
- {
87
- "start_node": {"id": "node3", "type": "Entity", "name": "Charlie"},
88
- "end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
89
- "relationship_properties": {
90
- "relationship_name": "works_with",
91
- },
92
- },
93
- ]
94
-
95
- mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch
96
-
97
- with (
98
- patch.object(
99
- get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
100
- ),
101
- patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
102
- ):
103
- mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]
104
-
105
- triplets = []
106
- async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
107
- triplets.append(triplet)
108
-
109
- assert len(triplets) == 2
110
- assert "has a close friendship with" in triplets[0].text
111
- assert "knows" not in triplets[0].text
112
- assert "works_with" in triplets[1].text
113
-
114
-
115
- @pytest.mark.asyncio
116
- async def test_get_triplet_datapoints_skips_missing_node_ids(mock_graph_engine):
117
- """Test that triplets with missing node IDs are skipped."""
118
-
119
- class MockEntity(DataPoint):
120
- name: str
121
- metadata: dict = {"index_fields": ["name"]}
122
-
123
- mock_triplets_batch = [
124
- {
125
- "start_node": {"id": "", "type": "Entity", "name": "Alice"},
126
- "end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
127
- "relationship_properties": {"relationship_name": "knows"},
128
- },
129
- {
130
- "start_node": {"id": "node3", "type": "Entity", "name": "Charlie"},
131
- "end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
132
- "relationship_properties": {"relationship_name": "works_with"},
133
- },
134
- ]
135
-
136
- mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch
137
-
138
- with (
139
- patch.object(
140
- get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
141
- ),
142
- patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
143
- ):
144
- mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]
145
-
146
- triplets = []
147
- async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
148
- triplets.append(triplet)
149
-
150
- assert len(triplets) == 1
151
- assert triplets[0].from_node_id == "node3"
152
-
153
-
154
- @pytest.mark.asyncio
155
- async def test_get_triplet_datapoints_error_handling(mock_graph_engine):
156
- """Test that errors are handled correctly - invalid data is skipped, query errors propagate."""
157
-
158
- class MockEntity(DataPoint):
159
- name: str
160
- metadata: dict = {"index_fields": ["name"]}
161
-
162
- mock_triplets_batch = [
163
- {
164
- "start_node": {"id": "node1", "type": "Entity", "name": "Alice"},
165
- "end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
166
- "relationship_properties": {"relationship_name": "knows"},
167
- },
168
- {
169
- "start_node": None,
170
- "end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
171
- "relationship_properties": {"relationship_name": "works_with"},
172
- },
173
- ]
174
-
175
- mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch
176
-
177
- with (
178
- patch.object(
179
- get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
180
- ),
181
- patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
182
- ):
183
- mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]
184
-
185
- triplets = []
186
- async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
187
- triplets.append(triplet)
188
-
189
- assert len(triplets) == 1
190
- assert triplets[0].from_node_id == "node1"
191
-
192
- mock_graph_engine.get_triplets_batch.side_effect = Exception("Database connection error")
193
-
194
- with patch.object(
195
- get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
196
- ):
197
- triplets = []
198
- with pytest.raises(Exception, match="Database connection error"):
199
- async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
200
- triplets.append(triplet)
201
-
202
-
203
- @pytest.mark.asyncio
204
- async def test_get_triplet_datapoints_no_get_triplets_batch_method(mock_graph_engine):
205
- """Test that NotImplementedError is raised when graph engine lacks get_triplets_batch."""
206
- del mock_graph_engine.get_triplets_batch
207
-
208
- with patch.object(
209
- get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
210
- ):
211
- triplets = []
212
- with pytest.raises(NotImplementedError, match="does not support get_triplets_batch"):
213
- async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
214
- triplets.append(triplet)