cognee 0.4.1__py3-none-any.whl → 0.5.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +1 -0
- cognee/api/client.py +8 -0
- cognee/api/v1/add/routers/get_add_router.py +3 -1
- cognee/api/v1/cognify/routers/get_cognify_router.py +28 -1
- cognee/api/v1/ontologies/__init__.py +4 -0
- cognee/api/v1/ontologies/ontologies.py +183 -0
- cognee/api/v1/ontologies/routers/__init__.py +0 -0
- cognee/api/v1/ontologies/routers/get_ontology_router.py +107 -0
- cognee/api/v1/permissions/routers/get_permissions_router.py +41 -1
- cognee/cli/commands/cognify_command.py +8 -1
- cognee/cli/config.py +1 -1
- cognee/context_global_variables.py +41 -9
- cognee/infrastructure/databases/cache/config.py +3 -1
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +151 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +20 -10
- cognee/infrastructure/databases/exceptions/exceptions.py +16 -0
- cognee/infrastructure/databases/graph/config.py +4 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +2 -0
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +9 -0
- cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +37 -3
- cognee/infrastructure/databases/vector/config.py +3 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +5 -1
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +1 -4
- cognee/infrastructure/engine/models/Edge.py +13 -1
- cognee/infrastructure/files/utils/guess_file_type.py +4 -0
- cognee/infrastructure/llm/config.py +2 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +5 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +7 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +7 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +8 -16
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +12 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +13 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +5 -2
- cognee/infrastructure/loaders/LoaderEngine.py +1 -0
- cognee/infrastructure/loaders/core/__init__.py +2 -1
- cognee/infrastructure/loaders/core/csv_loader.py +93 -0
- cognee/infrastructure/loaders/core/text_loader.py +1 -2
- cognee/infrastructure/loaders/external/advanced_pdf_loader.py +0 -9
- cognee/infrastructure/loaders/supported_loaders.py +2 -1
- cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py +55 -0
- cognee/modules/chunking/CsvChunker.py +35 -0
- cognee/modules/chunking/models/DocumentChunk.py +2 -1
- cognee/modules/chunking/text_chunker_with_overlap.py +124 -0
- cognee/modules/data/methods/__init__.py +1 -0
- cognee/modules/data/methods/create_dataset.py +4 -2
- cognee/modules/data/methods/get_dataset_ids.py +5 -1
- cognee/modules/data/methods/get_unique_data_id.py +68 -0
- cognee/modules/data/methods/get_unique_dataset_id.py +66 -4
- cognee/modules/data/models/Dataset.py +2 -0
- cognee/modules/data/processing/document_types/CsvDocument.py +33 -0
- cognee/modules/data/processing/document_types/__init__.py +1 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +4 -2
- cognee/modules/graph/utils/expand_with_nodes_and_edges.py +19 -2
- cognee/modules/graph/utils/resolve_edges_to_text.py +48 -49
- cognee/modules/ingestion/identify.py +4 -4
- cognee/modules/notebooks/operations/run_in_local_sandbox.py +3 -0
- cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +55 -23
- cognee/modules/pipelines/operations/run_tasks_data_item.py +1 -1
- cognee/modules/retrieval/EntityCompletionRetriever.py +10 -3
- cognee/modules/retrieval/base_graph_retriever.py +7 -3
- cognee/modules/retrieval/base_retriever.py +7 -3
- cognee/modules/retrieval/completion_retriever.py +11 -4
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +6 -2
- cognee/modules/retrieval/graph_completion_cot_retriever.py +14 -51
- cognee/modules/retrieval/graph_completion_retriever.py +4 -1
- cognee/modules/retrieval/temporal_retriever.py +9 -2
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +1 -1
- cognee/modules/retrieval/utils/completion.py +2 -22
- cognee/modules/run_custom_pipeline/__init__.py +1 -0
- cognee/modules/run_custom_pipeline/run_custom_pipeline.py +69 -0
- cognee/modules/search/methods/search.py +5 -3
- cognee/modules/users/methods/create_user.py +12 -27
- cognee/modules/users/methods/get_authenticated_user.py +2 -1
- cognee/modules/users/methods/get_default_user.py +4 -2
- cognee/modules/users/methods/get_user.py +1 -1
- cognee/modules/users/methods/get_user_by_email.py +1 -1
- cognee/modules/users/models/DatasetDatabase.py +9 -0
- cognee/modules/users/models/Tenant.py +6 -7
- cognee/modules/users/models/User.py +6 -5
- cognee/modules/users/models/UserTenant.py +12 -0
- cognee/modules/users/models/__init__.py +1 -0
- cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py +13 -13
- cognee/modules/users/roles/methods/add_user_to_role.py +3 -1
- cognee/modules/users/tenants/methods/__init__.py +1 -0
- cognee/modules/users/tenants/methods/add_user_to_tenant.py +21 -12
- cognee/modules/users/tenants/methods/create_tenant.py +22 -8
- cognee/modules/users/tenants/methods/select_tenant.py +62 -0
- cognee/shared/logging_utils.py +2 -0
- cognee/tasks/chunks/__init__.py +1 -0
- cognee/tasks/chunks/chunk_by_row.py +94 -0
- cognee/tasks/documents/classify_documents.py +2 -0
- cognee/tasks/feedback/generate_improved_answers.py +3 -3
- cognee/tasks/ingestion/ingest_data.py +1 -1
- cognee/tasks/memify/__init__.py +2 -0
- cognee/tasks/memify/cognify_session.py +41 -0
- cognee/tasks/memify/extract_user_sessions.py +73 -0
- cognee/tasks/storage/index_data_points.py +33 -22
- cognee/tasks/storage/index_graph_edges.py +37 -57
- cognee/tests/integration/documents/CsvDocument_test.py +70 -0
- cognee/tests/tasks/entity_extraction/entity_extraction_test.py +1 -1
- cognee/tests/test_add_docling_document.py +2 -2
- cognee/tests/test_cognee_server_start.py +84 -1
- cognee/tests/test_conversation_history.py +45 -4
- cognee/tests/test_data/example_with_header.csv +3 -0
- cognee/tests/test_delete_bmw_example.py +60 -0
- cognee/tests/test_edge_ingestion.py +27 -0
- cognee/tests/test_feedback_enrichment.py +1 -1
- cognee/tests/test_library.py +6 -4
- cognee/tests/test_load.py +62 -0
- cognee/tests/test_multi_tenancy.py +165 -0
- cognee/tests/test_parallel_databases.py +2 -0
- cognee/tests/test_relational_db_migration.py +54 -2
- cognee/tests/test_search_db.py +7 -1
- cognee/tests/unit/api/test_conditional_authentication_endpoints.py +12 -3
- cognee/tests/unit/api/test_ontology_endpoint.py +264 -0
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +5 -0
- cognee/tests/unit/infrastructure/databases/test_index_data_points.py +27 -0
- cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +14 -16
- cognee/tests/unit/modules/chunking/test_text_chunker.py +248 -0
- cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py +324 -0
- cognee/tests/unit/modules/memify_tasks/test_cognify_session.py +111 -0
- cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py +175 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +0 -51
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +1 -0
- cognee/tests/unit/modules/retrieval/structured_output_test.py +204 -0
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +1 -1
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +0 -1
- cognee/tests/unit/modules/users/test_conditional_authentication.py +0 -63
- cognee/tests/unit/processing/chunks/chunk_by_row_test.py +52 -0
- {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/METADATA +88 -71
- {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/RECORD +135 -104
- {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/WHEEL +1 -1
- {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/entry_points.txt +0 -1
- {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import pytest
|
|
3
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
4
|
+
|
|
5
|
+
from cognee.tasks.memify.extract_user_sessions import extract_user_sessions
|
|
6
|
+
from cognee.exceptions import CogneeSystemError
|
|
7
|
+
from cognee.modules.users.models import User
|
|
8
|
+
|
|
9
|
+
# Get the actual module object (not the function) for patching
|
|
10
|
+
extract_user_sessions_module = sys.modules["cognee.tasks.memify.extract_user_sessions"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@pytest.fixture
def mock_user():
    """Provide a ``User``-spec'd mock carrying a fixed id."""
    fake_user = MagicMock(spec=User)
    fake_user.configure_mock(id="test-user-123")
    return fake_user
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@pytest.fixture
def mock_qa_data():
    """Provide two Q&A records, each with question, context, answer and time keys."""
    field_names = ("question", "context", "answer", "time")
    records = (
        (
            "What is cognee?",
            "context about cognee",
            "Cognee is a knowledge graph solution",
            "2025-01-01T12:00:00",
        ),
        (
            "How does it work?",
            "how it works context",
            "It processes data and creates graphs",
            "2025-01-01T12:05:00",
        ),
    )
    # zip keeps the key order identical to the field_names tuple.
    return [dict(zip(field_names, record)) for record in records]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@pytest.mark.asyncio
async def test_extract_user_sessions_success(mock_user, mock_qa_data):
    """One requested session yields one item containing the session id and every Q&A pair."""
    cache_engine = AsyncMock()
    cache_engine.get_all_qas.return_value = mock_qa_data

    # Swap session_user and get_cache_engine on the task module for the duration of the call.
    with patch.object(extract_user_sessions_module, "session_user") as session_user_mock:
        with patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=cache_engine
        ):
            session_user_mock.get.return_value = mock_user

            sessions = [
                session
                async for session in extract_user_sessions([{}], session_ids=["test_session"])
            ]

    assert len(sessions) == 1
    rendered = sessions[0]
    assert "Session ID: test_session" in rendered
    assert "Question: What is cognee?" in rendered
    assert "Answer: Cognee is a knowledge graph solution" in rendered
    assert "Question: How does it work?" in rendered
    assert "Answer: It processes data and creates graphs" in rendered
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@pytest.mark.asyncio
async def test_extract_user_sessions_multiple_sessions(mock_user, mock_qa_data):
    """Two requested session ids yield two items and two cache lookups."""
    cache_engine = AsyncMock()
    cache_engine.get_all_qas.return_value = mock_qa_data
    requested_ids = ["session1", "session2"]

    # Swap session_user and get_cache_engine on the task module for the duration of the call.
    with patch.object(extract_user_sessions_module, "session_user") as session_user_mock:
        with patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=cache_engine
        ):
            session_user_mock.get.return_value = mock_user

            sessions = [
                session async for session in extract_user_sessions([{}], session_ids=requested_ids)
            ]

    assert len(sessions) == 2
    assert cache_engine.get_all_qas.call_count == 2
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@pytest.mark.asyncio
async def test_extract_user_sessions_no_data(mock_user, mock_qa_data):
    """Passing ``None`` as the data argument still yields the requested session."""
    cache_engine = AsyncMock()
    cache_engine.get_all_qas.return_value = mock_qa_data

    # Swap session_user and get_cache_engine on the task module for the duration of the call.
    with patch.object(extract_user_sessions_module, "session_user") as session_user_mock:
        with patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=cache_engine
        ):
            session_user_mock.get.return_value = mock_user

            sessions = [
                session
                async for session in extract_user_sessions(None, session_ids=["test_session"])
            ]

    assert len(sessions) == 1
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@pytest.mark.asyncio
async def test_extract_user_sessions_no_session_ids(mock_user):
    """With ``session_ids=None`` nothing is yielded and the cache is never queried."""
    cache_engine = AsyncMock()

    # Swap session_user and get_cache_engine on the task module for the duration of the call.
    with patch.object(extract_user_sessions_module, "session_user") as session_user_mock:
        with patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=cache_engine
        ):
            session_user_mock.get.return_value = mock_user

            sessions = [session async for session in extract_user_sessions([{}], session_ids=None)]

    assert len(sessions) == 0
    cache_engine.get_all_qas.assert_not_called()
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@pytest.mark.asyncio
async def test_extract_user_sessions_empty_qa_data(mock_user):
    """A session whose cache lookup returns an empty list yields nothing."""
    cache_engine = AsyncMock()
    cache_engine.get_all_qas.return_value = []

    # Swap session_user and get_cache_engine on the task module for the duration of the call.
    with patch.object(extract_user_sessions_module, "session_user") as session_user_mock:
        with patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=cache_engine
        ):
            session_user_mock.get.return_value = mock_user

            sessions = [
                session
                async for session in extract_user_sessions([{}], session_ids=["empty_session"])
            ]

    assert len(sessions) == 0
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@pytest.mark.asyncio
async def test_extract_user_sessions_cache_error_handling(mock_user, mock_qa_data):
    """A cache failure on one session does not stop the remaining sessions from being yielded."""
    cache_engine = AsyncMock()
    # Lookups in order: first succeeds, second raises, third succeeds.
    cache_engine.get_all_qas.side_effect = [
        mock_qa_data,
        Exception("Cache error"),
        mock_qa_data,
    ]
    requested_ids = ["session1", "session2", "session3"]

    # Swap session_user and get_cache_engine on the task module for the duration of the call.
    with patch.object(extract_user_sessions_module, "session_user") as session_user_mock:
        with patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=cache_engine
        ):
            session_user_mock.get.return_value = mock_user

            sessions = [
                session async for session in extract_user_sessions([{}], session_ids=requested_ids)
            ]

    assert len(sessions) == 2
|
|
@@ -2,7 +2,6 @@ import os
|
|
|
2
2
|
import pytest
|
|
3
3
|
import pathlib
|
|
4
4
|
from typing import Optional, Union
|
|
5
|
-
from pydantic import BaseModel
|
|
6
5
|
|
|
7
6
|
import cognee
|
|
8
7
|
from cognee.low_level import setup, DataPoint
|
|
@@ -11,11 +10,6 @@ from cognee.tasks.storage import add_data_points
|
|
|
11
10
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
|
12
11
|
|
|
13
12
|
|
|
14
|
-
class TestAnswer(BaseModel):
|
|
15
|
-
answer: str
|
|
16
|
-
explanation: str
|
|
17
|
-
|
|
18
|
-
|
|
19
13
|
class TestGraphCompletionCoTRetriever:
|
|
20
14
|
@pytest.mark.asyncio
|
|
21
15
|
async def test_graph_completion_cot_context_simple(self):
|
|
@@ -174,48 +168,3 @@ class TestGraphCompletionCoTRetriever:
|
|
|
174
168
|
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
|
175
169
|
"Answer must contain only non-empty strings"
|
|
176
170
|
)
|
|
177
|
-
|
|
178
|
-
@pytest.mark.asyncio
|
|
179
|
-
async def test_get_structured_completion(self):
|
|
180
|
-
system_directory_path = os.path.join(
|
|
181
|
-
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
|
|
182
|
-
)
|
|
183
|
-
cognee.config.system_root_directory(system_directory_path)
|
|
184
|
-
data_directory_path = os.path.join(
|
|
185
|
-
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
|
|
186
|
-
)
|
|
187
|
-
cognee.config.data_root_directory(data_directory_path)
|
|
188
|
-
|
|
189
|
-
await cognee.prune.prune_data()
|
|
190
|
-
await cognee.prune.prune_system(metadata=True)
|
|
191
|
-
await setup()
|
|
192
|
-
|
|
193
|
-
class Company(DataPoint):
|
|
194
|
-
name: str
|
|
195
|
-
|
|
196
|
-
class Person(DataPoint):
|
|
197
|
-
name: str
|
|
198
|
-
works_for: Company
|
|
199
|
-
|
|
200
|
-
company1 = Company(name="Figma")
|
|
201
|
-
person1 = Person(name="Steve Rodger", works_for=company1)
|
|
202
|
-
|
|
203
|
-
entities = [company1, person1]
|
|
204
|
-
await add_data_points(entities)
|
|
205
|
-
|
|
206
|
-
retriever = GraphCompletionCotRetriever()
|
|
207
|
-
|
|
208
|
-
# Test with string response model (default)
|
|
209
|
-
string_answer = await retriever.get_structured_completion("Who works at Figma?")
|
|
210
|
-
assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
|
|
211
|
-
assert string_answer.strip(), "Answer should not be empty"
|
|
212
|
-
|
|
213
|
-
# Test with structured response model
|
|
214
|
-
structured_answer = await retriever.get_structured_completion(
|
|
215
|
-
"Who works at Figma?", response_model=TestAnswer
|
|
216
|
-
)
|
|
217
|
-
assert isinstance(structured_answer, TestAnswer), (
|
|
218
|
-
f"Expected TestAnswer, got {type(structured_answer).__name__}"
|
|
219
|
-
)
|
|
220
|
-
assert structured_answer.answer.strip(), "Answer field should not be empty"
|
|
221
|
-
assert structured_answer.explanation.strip(), "Explanation field should not be empty"
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
import cognee
|
|
5
|
+
import pathlib
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
from cognee.low_level import setup, DataPoint
|
|
10
|
+
from cognee.tasks.storage import add_data_points
|
|
11
|
+
from cognee.modules.chunking.models import DocumentChunk
|
|
12
|
+
from cognee.modules.data.processing.document_types import TextDocument
|
|
13
|
+
from cognee.modules.engine.models import Entity, EntityType
|
|
14
|
+
from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor
|
|
15
|
+
from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider
|
|
16
|
+
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
|
17
|
+
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
|
18
|
+
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
|
19
|
+
GraphCompletionContextExtensionRetriever,
|
|
20
|
+
)
|
|
21
|
+
from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever
|
|
22
|
+
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
|
23
|
+
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Response model passed as ``response_model`` to the retrievers under test.
class TestAnswer(BaseModel):
    # The answer text; the checks in this file require it to be non-blank.
    answer: str
    # Accompanying explanation; also required to be non-blank by the checks.
    explanation: str
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _assert_string_answer(answer: list[str]):
    """Assert that *answer* is a list whose items are all non-blank strings.

    An empty list passes, because every per-item check is vacuously true.

    Raises:
        AssertionError: if *answer* is not a list, if any item is not a
            ``str``, or if any item is empty or whitespace-only.
    """
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    # Check types first so a blank string is not reported as a type failure.
    assert all(isinstance(item, str) for item in answer), "Items should be strings"
    assert all(item.strip() for item in answer), "Items should not be empty"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _assert_structured_answer(answer: list[TestAnswer]):
    """Assert that *answer* is a list of ``TestAnswer`` items with non-blank fields.

    The checks run as three separate passes (type, answer text, explanation),
    so a type failure anywhere in the list is reported before any field check.
    """
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    for item in answer:
        assert isinstance(item, TestAnswer), "Items should be TestAnswer"
    for item in answer:
        assert item.answer.strip(), "Answer text should not be empty"
    for item in answer:
        assert item.explanation.strip(), "Explanation should not be empty"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
async def _test_get_structured_graph_completion_cot():
    """Check ``GraphCompletionCotRetriever.get_completion`` with and without ``response_model``."""
    question = "Who works at Figma?"
    cot_retriever = GraphCompletionCotRetriever()

    # No response_model: the result is checked as a list of non-blank strings.
    plain_result = await cot_retriever.get_completion(question)
    _assert_string_answer(plain_result)

    # response_model=TestAnswer: the result is checked as a list of TestAnswer.
    typed_result = await cot_retriever.get_completion(question, response_model=TestAnswer)
    _assert_structured_answer(typed_result)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
async def _test_get_structured_graph_completion():
    """Check ``GraphCompletionRetriever.get_completion`` with and without ``response_model``."""
    question = "Who works at Figma?"
    graph_retriever = GraphCompletionRetriever()

    # No response_model: the result is checked as a list of non-blank strings.
    plain_result = await graph_retriever.get_completion(question)
    _assert_string_answer(plain_result)

    # response_model=TestAnswer: the result is checked as a list of TestAnswer.
    typed_result = await graph_retriever.get_completion(question, response_model=TestAnswer)
    _assert_structured_answer(typed_result)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
async def _test_get_structured_graph_completion_temporal():
    """Check ``TemporalRetriever.get_completion`` with and without ``response_model``."""
    retriever = TemporalRetriever()
    # Single shared query so both response modes receive identical input.
    query = "When did Steve start working at Figma?"

    # Test with string response model (default)
    string_answer = await retriever.get_completion(query)
    _assert_string_answer(string_answer)

    # Test with structured response model
    structured_answer = await retriever.get_completion(query, response_model=TestAnswer)
    _assert_structured_answer(structured_answer)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
async def _test_get_structured_graph_completion_rag():
    """Check ``CompletionRetriever.get_completion`` with and without ``response_model``."""
    question = "Where does Steve work?"
    rag_retriever = CompletionRetriever()

    # No response_model: the result is checked as a list of non-blank strings.
    plain_result = await rag_retriever.get_completion(question)
    _assert_string_answer(plain_result)

    # response_model=TestAnswer: the result is checked as a list of TestAnswer.
    typed_result = await rag_retriever.get_completion(question, response_model=TestAnswer)
    _assert_structured_answer(typed_result)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
async def _test_get_structured_graph_completion_context_extension():
    """Check ``GraphCompletionContextExtensionRetriever.get_completion`` in both response modes."""
    question = "Who works at Figma?"
    extension_retriever = GraphCompletionContextExtensionRetriever()

    # No response_model: the result is checked as a list of non-blank strings.
    plain_result = await extension_retriever.get_completion(question)
    _assert_string_answer(plain_result)

    # response_model=TestAnswer: the result is checked as a list of TestAnswer.
    typed_result = await extension_retriever.get_completion(question, response_model=TestAnswer)
    _assert_structured_answer(typed_result)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
async def _test_get_structured_entity_completion():
    """Check ``EntityCompletionRetriever.get_completion`` with and without ``response_model``."""
    question = "Who is Albert Einstein?"
    extractor = DummyEntityExtractor()
    context_provider = DummyContextProvider()
    entity_retriever = EntityCompletionRetriever(extractor, context_provider)

    # No response_model: the result is checked as a list of non-blank strings.
    plain_result = await entity_retriever.get_completion(question)
    _assert_string_answer(plain_result)

    # response_model=TestAnswer: the result is checked as a list of TestAnswer.
    typed_result = await entity_retriever.get_completion(question, response_model=TestAnswer)
    _assert_structured_answer(typed_result)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class TestStructuredOutputCompletion:
    """Seeds a store once, then runs the structured-output helper checks against it."""

    @pytest.mark.asyncio
    async def test_get_structured_completion(self):
        """Configure test-local directories, add sample data points, run every helper check."""
        test_dir = pathlib.Path(__file__).parent

        # Keep system and data directories next to this test file.
        cognee.config.system_root_directory(
            os.path.join(test_dir, ".cognee_system/test_get_structured_completion")
        )
        cognee.config.data_root_directory(
            os.path.join(test_dir, ".data_storage/test_get_structured_completion")
        )

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company
            works_since: int

        figma = Company(name="Figma")
        steve = Person(name="Steve Rodger", works_for=figma, works_since=2015)
        await add_data_points([figma, steve])

        career_document = TextDocument(
            name="Steve Rodger's career",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        # Three chunks of the same document, indexed 0..2 in this order.
        chunk_texts = ("Steve Rodger", "Mike Broski", "Christina Mayer")
        document_chunks = [
            DocumentChunk(
                text=chunk_text,
                chunk_size=2,
                chunk_index=position,
                cut_type="sentence_end",
                is_part_of=career_document,
                contains=[],
            )
            for position, chunk_text in enumerate(chunk_texts)
        ]
        await add_data_points(document_chunks)

        person_type = EntityType(name="Person", description="A human individual")
        einstein = Entity(
            name="Albert Einstein", is_a=person_type, description="A famous physicist"
        )
        await add_data_points([einstein])

        # Helpers run sequentially, in this order, against the data added above.
        helper_checks = (
            _test_get_structured_graph_completion_cot,
            _test_get_structured_graph_completion,
            _test_get_structured_graph_completion_temporal,
            _test_get_structured_graph_completion_rag,
            _test_get_structured_graph_completion_context_extension,
            _test_get_structured_entity_completion,
        )
        for helper_check in helper_checks:
            await helper_check()
|
|
@@ -13,7 +13,7 @@ from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
|
|
13
13
|
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
|
14
14
|
|
|
15
15
|
|
|
16
|
-
class
|
|
16
|
+
class TestSummariesRetriever:
|
|
17
17
|
@pytest.mark.asyncio
|
|
18
18
|
async def test_chunk_context(self):
|
|
19
19
|
system_directory_path = os.path.join(
|
|
@@ -107,29 +107,10 @@ class TestConditionalAuthenticationIntegration:
|
|
|
107
107
|
# REQUIRE_AUTHENTICATION should be a boolean
|
|
108
108
|
assert isinstance(REQUIRE_AUTHENTICATION, bool)
|
|
109
109
|
|
|
110
|
-
# Currently should be False (optional authentication)
|
|
111
|
-
assert not REQUIRE_AUTHENTICATION
|
|
112
|
-
|
|
113
110
|
|
|
114
111
|
class TestConditionalAuthenticationEnvironmentVariables:
|
|
115
112
|
"""Test environment variable handling."""
|
|
116
113
|
|
|
117
|
-
def test_require_authentication_default_false(self):
|
|
118
|
-
"""Test that REQUIRE_AUTHENTICATION defaults to false when imported with no env vars."""
|
|
119
|
-
with patch.dict(os.environ, {}, clear=True):
|
|
120
|
-
# Remove module from cache to force fresh import
|
|
121
|
-
module_name = "cognee.modules.users.methods.get_authenticated_user"
|
|
122
|
-
if module_name in sys.modules:
|
|
123
|
-
del sys.modules[module_name]
|
|
124
|
-
|
|
125
|
-
# Import after patching environment - module will see empty environment
|
|
126
|
-
from cognee.modules.users.methods.get_authenticated_user import (
|
|
127
|
-
REQUIRE_AUTHENTICATION,
|
|
128
|
-
)
|
|
129
|
-
|
|
130
|
-
importlib.invalidate_caches()
|
|
131
|
-
assert not REQUIRE_AUTHENTICATION
|
|
132
|
-
|
|
133
114
|
def test_require_authentication_true(self):
|
|
134
115
|
"""Test that REQUIRE_AUTHENTICATION=true is parsed correctly when imported."""
|
|
135
116
|
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "true"}):
|
|
@@ -145,50 +126,6 @@ class TestConditionalAuthenticationEnvironmentVariables:
|
|
|
145
126
|
|
|
146
127
|
assert REQUIRE_AUTHENTICATION
|
|
147
128
|
|
|
148
|
-
def test_require_authentication_false_explicit(self):
|
|
149
|
-
"""Test that REQUIRE_AUTHENTICATION=false is parsed correctly when imported."""
|
|
150
|
-
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "false"}):
|
|
151
|
-
# Remove module from cache to force fresh import
|
|
152
|
-
module_name = "cognee.modules.users.methods.get_authenticated_user"
|
|
153
|
-
if module_name in sys.modules:
|
|
154
|
-
del sys.modules[module_name]
|
|
155
|
-
|
|
156
|
-
# Import after patching environment - module will see REQUIRE_AUTHENTICATION=false
|
|
157
|
-
from cognee.modules.users.methods.get_authenticated_user import (
|
|
158
|
-
REQUIRE_AUTHENTICATION,
|
|
159
|
-
)
|
|
160
|
-
|
|
161
|
-
assert not REQUIRE_AUTHENTICATION
|
|
162
|
-
|
|
163
|
-
def test_require_authentication_case_insensitive(self):
|
|
164
|
-
"""Test that environment variable parsing is case insensitive when imported."""
|
|
165
|
-
test_cases = ["TRUE", "True", "tRuE", "FALSE", "False", "fAlSe"]
|
|
166
|
-
|
|
167
|
-
for case in test_cases:
|
|
168
|
-
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": case}):
|
|
169
|
-
# Remove module from cache to force fresh import
|
|
170
|
-
module_name = "cognee.modules.users.methods.get_authenticated_user"
|
|
171
|
-
if module_name in sys.modules:
|
|
172
|
-
del sys.modules[module_name]
|
|
173
|
-
|
|
174
|
-
# Import after patching environment
|
|
175
|
-
from cognee.modules.users.methods.get_authenticated_user import (
|
|
176
|
-
REQUIRE_AUTHENTICATION,
|
|
177
|
-
)
|
|
178
|
-
|
|
179
|
-
expected = case.lower() == "true"
|
|
180
|
-
assert REQUIRE_AUTHENTICATION == expected, f"Failed for case: {case}"
|
|
181
|
-
|
|
182
|
-
def test_current_require_authentication_value(self):
|
|
183
|
-
"""Test that the current REQUIRE_AUTHENTICATION module value is as expected."""
|
|
184
|
-
from cognee.modules.users.methods.get_authenticated_user import (
|
|
185
|
-
REQUIRE_AUTHENTICATION,
|
|
186
|
-
)
|
|
187
|
-
|
|
188
|
-
# The module-level variable should currently be False (set at import time)
|
|
189
|
-
assert isinstance(REQUIRE_AUTHENTICATION, bool)
|
|
190
|
-
assert not REQUIRE_AUTHENTICATION
|
|
191
|
-
|
|
192
129
|
|
|
193
130
|
class TestConditionalAuthenticationEdgeCases:
|
|
194
131
|
"""Test edge cases and error scenarios."""
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from itertools import product
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
|
7
|
+
from cognee.tasks.chunks import chunk_by_row
|
|
8
|
+
|
|
9
|
+
INPUT_TEXTS = "name: John, age: 30, city: New York, country: USA"
|
|
10
|
+
max_chunk_size_vals = [8, 32]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_chunk_by_row_isomorphism(input_text, max_chunk_size):
    """Joining the chunk texts with ", " must reproduce the input text exactly."""
    chunk_texts = [chunk["text"] for chunk in chunk_by_row(input_text, max_chunk_size)]
    reconstructed_text = ", ".join(chunk_texts)
    assert reconstructed_text == input_text, (
        f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
    )
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_row_chunk_length(input_text, max_chunk_size):
    """No chunk may exceed ``max_chunk_size`` tokens as counted by the embedding tokenizer."""
    chunks = list(chunk_by_row(data=input_text, max_chunk_size=max_chunk_size))
    embedding_engine = get_embedding_engine()

    token_counts = [embedding_engine.tokenizer.count_tokens(chunk["text"]) for chunk in chunks]
    chunk_lengths = np.array(token_counts)

    # Boolean mask of offending chunks; reused for both the check and the message.
    oversized = chunk_lengths > max_chunk_size
    larger_chunks = chunk_lengths[oversized]
    assert not np.any(oversized), (
        f"{max_chunk_size = }: {larger_chunks} are too large"
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_chunk_by_row_chunk_numbering(input_text, max_chunk_size):
    """Chunk indices must be exactly 0, 1, 2, ... in the order the chunks are produced."""
    chunk_indices = np.array(
        [
            chunk["chunk_index"]
            for chunk in chunk_by_row(data=input_text, max_chunk_size=max_chunk_size)
        ]
    )
    expected_indices = np.arange(len(chunk_indices))
    assert np.array_equal(chunk_indices, expected_indices), (
        f"{chunk_indices = } are not monotonically increasing"
    )
|