cognee 0.5.0.dev1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +1 -0
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/data/models/Data.py +2 -1
- cognee/modules/retrieval/triplet_retriever.py +1 -1
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
- cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
- cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_search_db.py +334 -181
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +56 -42
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -1,7 +1,12 @@
|
|
|
1
1
|
from types import SimpleNamespace
|
|
2
2
|
import pytest
|
|
3
|
+
import os
|
|
4
|
+
from unittest.mock import AsyncMock, patch, MagicMock
|
|
5
|
+
from datetime import datetime
|
|
3
6
|
|
|
4
7
|
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
|
8
|
+
from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp
|
|
9
|
+
from cognee.infrastructure.llm import LLMGateway
|
|
5
10
|
|
|
6
11
|
|
|
7
12
|
# Test TemporalRetriever initialization defaults and overrides
|
|
@@ -140,85 +145,561 @@ async def test_filter_top_k_events_error_handling():
|
|
|
140
145
|
await tr.filter_top_k_events([{}], [])
|
|
141
146
|
|
|
142
147
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
148
|
+
@pytest.fixture
|
|
149
|
+
def mock_graph_engine():
|
|
150
|
+
"""Create a mock graph engine."""
|
|
151
|
+
engine = AsyncMock()
|
|
152
|
+
engine.collect_time_ids = AsyncMock()
|
|
153
|
+
engine.collect_events = AsyncMock()
|
|
154
|
+
return engine
|
|
147
155
|
|
|
148
|
-
|
|
149
|
-
|
|
156
|
+
|
|
157
|
+
@pytest.fixture
|
|
158
|
+
def mock_vector_engine():
|
|
159
|
+
"""Create a mock vector engine."""
|
|
160
|
+
engine = AsyncMock()
|
|
161
|
+
engine.embedding_engine = AsyncMock()
|
|
162
|
+
engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
|
163
|
+
engine.search = AsyncMock()
|
|
164
|
+
return engine
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@pytest.mark.asyncio
|
|
168
|
+
async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine):
|
|
169
|
+
"""Test get_context when time range is extracted from query."""
|
|
170
|
+
retriever = TemporalRetriever(top_k=5)
|
|
171
|
+
|
|
172
|
+
mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"]
|
|
173
|
+
mock_graph_engine.collect_events.return_value = [
|
|
174
|
+
{
|
|
175
|
+
"events": [
|
|
176
|
+
{"id": "e1", "description": "Event 1"},
|
|
177
|
+
{"id": "e2", "description": "Event 2"},
|
|
178
|
+
]
|
|
179
|
+
}
|
|
180
|
+
]
|
|
181
|
+
|
|
182
|
+
mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05)
|
|
183
|
+
mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10)
|
|
184
|
+
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
|
|
185
|
+
|
|
186
|
+
with (
|
|
187
|
+
patch.object(
|
|
188
|
+
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
|
|
189
|
+
),
|
|
190
|
+
patch(
|
|
191
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
192
|
+
return_value=mock_graph_engine,
|
|
193
|
+
),
|
|
194
|
+
patch(
|
|
195
|
+
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
|
|
196
|
+
return_value=mock_vector_engine,
|
|
197
|
+
),
|
|
198
|
+
):
|
|
199
|
+
context = await retriever.get_context("What happened in 2024?")
|
|
200
|
+
|
|
201
|
+
assert isinstance(context, str)
|
|
202
|
+
assert len(context) > 0
|
|
203
|
+
assert "Event" in context
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@pytest.mark.asyncio
|
|
207
|
+
async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine):
|
|
208
|
+
"""Test get_context falls back to triplets when no time is extracted."""
|
|
209
|
+
retriever = TemporalRetriever()
|
|
210
|
+
|
|
211
|
+
with (
|
|
212
|
+
patch(
|
|
213
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
214
|
+
return_value=mock_graph_engine,
|
|
215
|
+
),
|
|
216
|
+
patch.object(
|
|
217
|
+
retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
|
|
218
|
+
) as mock_get_triplets,
|
|
219
|
+
patch.object(
|
|
220
|
+
retriever, "resolve_edges_to_text", return_value="triplet text"
|
|
221
|
+
) as mock_resolve,
|
|
222
|
+
):
|
|
223
|
+
|
|
224
|
+
async def mock_extract_time(query):
|
|
225
|
+
return None, None
|
|
226
|
+
|
|
227
|
+
retriever.extract_time_from_query = mock_extract_time
|
|
228
|
+
|
|
229
|
+
context = await retriever.get_context("test query")
|
|
230
|
+
|
|
231
|
+
assert context == "triplet text"
|
|
232
|
+
mock_get_triplets.assert_awaited_once_with("test query")
|
|
233
|
+
mock_resolve.assert_awaited_once()
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
@pytest.mark.asyncio
|
|
237
|
+
async def test_get_context_no_events_found(mock_graph_engine):
|
|
238
|
+
"""Test get_context falls back to triplets when no events are found."""
|
|
239
|
+
retriever = TemporalRetriever()
|
|
240
|
+
|
|
241
|
+
mock_graph_engine.collect_time_ids.return_value = []
|
|
242
|
+
|
|
243
|
+
with (
|
|
244
|
+
patch(
|
|
245
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
246
|
+
return_value=mock_graph_engine,
|
|
247
|
+
),
|
|
248
|
+
patch.object(
|
|
249
|
+
retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
|
|
250
|
+
) as mock_get_triplets,
|
|
251
|
+
patch.object(
|
|
252
|
+
retriever, "resolve_edges_to_text", return_value="triplet text"
|
|
253
|
+
) as mock_resolve,
|
|
254
|
+
):
|
|
255
|
+
|
|
256
|
+
async def mock_extract_time(query):
|
|
150
257
|
return "2024-01-01", "2024-12-31"
|
|
151
|
-
if "from_only" in query:
|
|
152
|
-
return "2024-01-01", None
|
|
153
|
-
if "to_only" in query:
|
|
154
|
-
return None, "2024-12-31"
|
|
155
|
-
return None, None
|
|
156
258
|
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
259
|
+
retriever.extract_time_from_query = mock_extract_time
|
|
260
|
+
|
|
261
|
+
context = await retriever.get_context("test query")
|
|
262
|
+
|
|
263
|
+
assert context == "triplet text"
|
|
264
|
+
mock_get_triplets.assert_awaited_once_with("test query")
|
|
265
|
+
mock_resolve.assert_awaited_once()
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@pytest.mark.asyncio
|
|
269
|
+
async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine):
|
|
270
|
+
"""Test get_context with only time_from."""
|
|
271
|
+
retriever = TemporalRetriever(top_k=5)
|
|
272
|
+
|
|
273
|
+
mock_graph_engine.collect_time_ids.return_value = ["e1"]
|
|
274
|
+
mock_graph_engine.collect_events.return_value = [
|
|
275
|
+
{
|
|
276
|
+
"events": [
|
|
277
|
+
{"id": "e1", "description": "Event 1"},
|
|
278
|
+
]
|
|
279
|
+
}
|
|
280
|
+
]
|
|
281
|
+
|
|
282
|
+
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
|
283
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
284
|
+
|
|
285
|
+
with (
|
|
286
|
+
patch.object(retriever, "extract_time_from_query", return_value=("2024-01-01", None)),
|
|
287
|
+
patch(
|
|
288
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
289
|
+
return_value=mock_graph_engine,
|
|
290
|
+
),
|
|
291
|
+
patch(
|
|
292
|
+
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
|
|
293
|
+
return_value=mock_vector_engine,
|
|
294
|
+
),
|
|
295
|
+
):
|
|
296
|
+
context = await retriever.get_context("What happened after 2024?")
|
|
297
|
+
|
|
298
|
+
assert isinstance(context, str)
|
|
299
|
+
assert "Event 1" in context
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
@pytest.mark.asyncio
|
|
303
|
+
async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine):
|
|
304
|
+
"""Test get_context with only time_to."""
|
|
305
|
+
retriever = TemporalRetriever(top_k=5)
|
|
306
|
+
|
|
307
|
+
mock_graph_engine.collect_time_ids.return_value = ["e1"]
|
|
308
|
+
mock_graph_engine.collect_events.return_value = [
|
|
309
|
+
{
|
|
310
|
+
"events": [
|
|
311
|
+
{"id": "e1", "description": "Event 1"},
|
|
312
|
+
]
|
|
313
|
+
}
|
|
314
|
+
]
|
|
315
|
+
|
|
316
|
+
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
|
317
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
318
|
+
|
|
319
|
+
with (
|
|
320
|
+
patch.object(retriever, "extract_time_from_query", return_value=(None, "2024-12-31")),
|
|
321
|
+
patch(
|
|
322
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
323
|
+
return_value=mock_graph_engine,
|
|
324
|
+
),
|
|
325
|
+
patch(
|
|
326
|
+
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
|
|
327
|
+
return_value=mock_vector_engine,
|
|
328
|
+
),
|
|
329
|
+
):
|
|
330
|
+
context = await retriever.get_context("What happened before 2024?")
|
|
331
|
+
|
|
332
|
+
assert isinstance(context, str)
|
|
333
|
+
assert "Event 1" in context
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
@pytest.mark.asyncio
|
|
337
|
+
async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine):
|
|
338
|
+
"""Test get_completion retrieves context when not provided."""
|
|
339
|
+
retriever = TemporalRetriever()
|
|
340
|
+
|
|
341
|
+
mock_graph_engine.collect_time_ids.return_value = ["e1"]
|
|
342
|
+
mock_graph_engine.collect_events.return_value = [
|
|
343
|
+
{
|
|
344
|
+
"events": [
|
|
345
|
+
{"id": "e1", "description": "Event 1"},
|
|
346
|
+
]
|
|
347
|
+
}
|
|
348
|
+
]
|
|
349
|
+
|
|
350
|
+
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
|
351
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
352
|
+
|
|
353
|
+
with (
|
|
354
|
+
patch.object(
|
|
355
|
+
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
|
|
356
|
+
),
|
|
357
|
+
patch(
|
|
358
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
359
|
+
return_value=mock_graph_engine,
|
|
360
|
+
),
|
|
361
|
+
patch(
|
|
362
|
+
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
|
|
363
|
+
return_value=mock_vector_engine,
|
|
364
|
+
),
|
|
365
|
+
patch(
|
|
366
|
+
"cognee.modules.retrieval.temporal_retriever.generate_completion",
|
|
367
|
+
return_value="Generated answer",
|
|
368
|
+
),
|
|
369
|
+
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
|
|
370
|
+
):
|
|
371
|
+
mock_config = MagicMock()
|
|
372
|
+
mock_config.caching = False
|
|
373
|
+
mock_cache_config.return_value = mock_config
|
|
374
|
+
|
|
375
|
+
completion = await retriever.get_completion("What happened in 2024?")
|
|
376
|
+
|
|
377
|
+
assert isinstance(completion, list)
|
|
378
|
+
assert len(completion) == 1
|
|
379
|
+
assert completion[0] == "Generated answer"
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
@pytest.mark.asyncio
|
|
383
|
+
async def test_get_completion_with_provided_context():
|
|
384
|
+
"""Test get_completion uses provided context."""
|
|
385
|
+
retriever = TemporalRetriever()
|
|
386
|
+
|
|
387
|
+
with (
|
|
388
|
+
patch(
|
|
389
|
+
"cognee.modules.retrieval.temporal_retriever.generate_completion",
|
|
390
|
+
return_value="Generated answer",
|
|
391
|
+
),
|
|
392
|
+
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
|
|
393
|
+
):
|
|
394
|
+
mock_config = MagicMock()
|
|
395
|
+
mock_config.caching = False
|
|
396
|
+
mock_cache_config.return_value = mock_config
|
|
397
|
+
|
|
398
|
+
completion = await retriever.get_completion("test query", context="Provided context")
|
|
399
|
+
|
|
400
|
+
assert isinstance(completion, list)
|
|
401
|
+
assert len(completion) == 1
|
|
402
|
+
assert completion[0] == "Generated answer"
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
@pytest.mark.asyncio
|
|
406
|
+
async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine):
|
|
407
|
+
"""Test get_completion with session caching enabled."""
|
|
408
|
+
retriever = TemporalRetriever()
|
|
409
|
+
|
|
410
|
+
mock_graph_engine.collect_time_ids.return_value = ["e1"]
|
|
411
|
+
mock_graph_engine.collect_events.return_value = [
|
|
412
|
+
{
|
|
413
|
+
"events": [
|
|
414
|
+
{"id": "e1", "description": "Event 1"},
|
|
415
|
+
]
|
|
416
|
+
}
|
|
417
|
+
]
|
|
418
|
+
|
|
419
|
+
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
|
420
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
421
|
+
|
|
422
|
+
mock_user = MagicMock()
|
|
423
|
+
mock_user.id = "test-user-id"
|
|
424
|
+
|
|
425
|
+
with (
|
|
426
|
+
patch.object(
|
|
427
|
+
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
|
|
428
|
+
),
|
|
429
|
+
patch(
|
|
430
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
431
|
+
return_value=mock_graph_engine,
|
|
432
|
+
),
|
|
433
|
+
patch(
|
|
434
|
+
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
|
|
435
|
+
return_value=mock_vector_engine,
|
|
436
|
+
),
|
|
437
|
+
patch(
|
|
438
|
+
"cognee.modules.retrieval.temporal_retriever.get_conversation_history",
|
|
439
|
+
return_value="Previous conversation",
|
|
440
|
+
),
|
|
441
|
+
patch(
|
|
442
|
+
"cognee.modules.retrieval.temporal_retriever.summarize_text",
|
|
443
|
+
return_value="Context summary",
|
|
444
|
+
),
|
|
445
|
+
patch(
|
|
446
|
+
"cognee.modules.retrieval.temporal_retriever.generate_completion",
|
|
447
|
+
return_value="Generated answer",
|
|
448
|
+
),
|
|
449
|
+
patch(
|
|
450
|
+
"cognee.modules.retrieval.temporal_retriever.save_conversation_history",
|
|
451
|
+
) as mock_save,
|
|
452
|
+
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
|
|
453
|
+
patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user,
|
|
454
|
+
):
|
|
455
|
+
mock_config = MagicMock()
|
|
456
|
+
mock_config.caching = True
|
|
457
|
+
mock_cache_config.return_value = mock_config
|
|
458
|
+
mock_session_user.get.return_value = mock_user
|
|
459
|
+
|
|
460
|
+
completion = await retriever.get_completion(
|
|
461
|
+
"What happened in 2024?", session_id="test_session"
|
|
462
|
+
)
|
|
160
463
|
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
464
|
+
assert isinstance(completion, list)
|
|
465
|
+
assert len(completion) == 1
|
|
466
|
+
assert completion[0] == "Generated answer"
|
|
467
|
+
mock_save.assert_awaited_once()
|
|
164
468
|
|
|
165
|
-
async def _fake_graph_collect_ids(self, **kwargs):
|
|
166
|
-
return ["e1", "e2"]
|
|
167
469
|
|
|
168
|
-
|
|
169
|
-
|
|
470
|
+
@pytest.mark.asyncio
|
|
471
|
+
async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine):
|
|
472
|
+
"""Test get_completion with session config but no user ID."""
|
|
473
|
+
retriever = TemporalRetriever()
|
|
474
|
+
|
|
475
|
+
mock_graph_engine.collect_time_ids.return_value = ["e1"]
|
|
476
|
+
mock_graph_engine.collect_events.return_value = [
|
|
477
|
+
{
|
|
478
|
+
"events": [
|
|
479
|
+
{"id": "e1", "description": "Event 1"},
|
|
480
|
+
]
|
|
481
|
+
}
|
|
482
|
+
]
|
|
483
|
+
|
|
484
|
+
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
|
485
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
486
|
+
|
|
487
|
+
with (
|
|
488
|
+
patch.object(
|
|
489
|
+
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
|
|
490
|
+
),
|
|
491
|
+
patch(
|
|
492
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
493
|
+
return_value=mock_graph_engine,
|
|
494
|
+
),
|
|
495
|
+
patch(
|
|
496
|
+
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
|
|
497
|
+
return_value=mock_vector_engine,
|
|
498
|
+
),
|
|
499
|
+
patch(
|
|
500
|
+
"cognee.modules.retrieval.temporal_retriever.generate_completion",
|
|
501
|
+
return_value="Generated answer",
|
|
502
|
+
),
|
|
503
|
+
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
|
|
504
|
+
patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user,
|
|
505
|
+
):
|
|
506
|
+
mock_config = MagicMock()
|
|
507
|
+
mock_config.caching = True
|
|
508
|
+
mock_cache_config.return_value = mock_config
|
|
509
|
+
mock_session_user.get.return_value = None # No user
|
|
510
|
+
|
|
511
|
+
completion = await retriever.get_completion("What happened in 2024?")
|
|
512
|
+
|
|
513
|
+
assert isinstance(completion, list)
|
|
514
|
+
assert len(completion) == 1
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
@pytest.mark.asyncio
|
|
518
|
+
async def test_get_completion_context_retrieved_but_empty(mock_graph_engine):
|
|
519
|
+
"""Test get_completion when get_context returns empty string."""
|
|
520
|
+
retriever = TemporalRetriever()
|
|
521
|
+
|
|
522
|
+
with (
|
|
523
|
+
patch.object(
|
|
524
|
+
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
|
|
525
|
+
),
|
|
526
|
+
patch(
|
|
527
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
528
|
+
return_value=mock_graph_engine,
|
|
529
|
+
),
|
|
530
|
+
patch(
|
|
531
|
+
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
|
|
532
|
+
) as mock_get_vector,
|
|
533
|
+
patch.object(retriever, "filter_top_k_events", return_value=[]),
|
|
534
|
+
):
|
|
535
|
+
mock_vector_engine = AsyncMock()
|
|
536
|
+
mock_vector_engine.embedding_engine = AsyncMock()
|
|
537
|
+
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
|
538
|
+
mock_vector_engine.search = AsyncMock(return_value=[])
|
|
539
|
+
mock_get_vector.return_value = mock_vector_engine
|
|
540
|
+
|
|
541
|
+
mock_graph_engine.collect_time_ids.return_value = ["e1"]
|
|
542
|
+
mock_graph_engine.collect_events.return_value = [
|
|
170
543
|
{
|
|
171
544
|
"events": [
|
|
172
|
-
{"id": "e1", "description": "
|
|
173
|
-
{"id": "e2", "description": "E2"},
|
|
174
|
-
{"id": "e3", "description": "E3"},
|
|
545
|
+
{"id": "e1", "description": ""},
|
|
175
546
|
]
|
|
176
547
|
}
|
|
177
548
|
]
|
|
178
549
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
return [[0.0, 1.0, 2.0]]
|
|
550
|
+
with pytest.raises((UnboundLocalError, NameError)):
|
|
551
|
+
await retriever.get_completion("test query")
|
|
182
552
|
|
|
183
|
-
async def _fake_vector_search(self, **kwargs):
|
|
184
|
-
return [
|
|
185
|
-
SimpleNamespace(payload={"id": "e2"}, score=0.05),
|
|
186
|
-
SimpleNamespace(payload={"id": "e1"}, score=0.10),
|
|
187
|
-
]
|
|
188
553
|
|
|
189
|
-
|
|
190
|
-
|
|
554
|
+
@pytest.mark.asyncio
|
|
555
|
+
async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine):
|
|
556
|
+
"""Test get_completion with custom response model."""
|
|
557
|
+
from pydantic import BaseModel
|
|
191
558
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
return await self.resolve_edges_to_text(triplets)
|
|
559
|
+
class TestModel(BaseModel):
|
|
560
|
+
answer: str
|
|
195
561
|
|
|
196
|
-
|
|
197
|
-
relevant_events = await self._fake_graph_collect_events(ids)
|
|
562
|
+
retriever = TemporalRetriever()
|
|
198
563
|
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
564
|
+
mock_graph_engine.collect_time_ids.return_value = ["e1"]
|
|
565
|
+
mock_graph_engine.collect_events.return_value = [
|
|
566
|
+
{
|
|
567
|
+
"events": [
|
|
568
|
+
{"id": "e1", "description": "Event 1"},
|
|
569
|
+
]
|
|
570
|
+
}
|
|
571
|
+
]
|
|
572
|
+
|
|
573
|
+
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
|
574
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
575
|
+
|
|
576
|
+
with (
|
|
577
|
+
patch.object(
|
|
578
|
+
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
|
|
579
|
+
),
|
|
580
|
+
patch(
|
|
581
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
582
|
+
return_value=mock_graph_engine,
|
|
583
|
+
),
|
|
584
|
+
patch(
|
|
585
|
+
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
|
|
586
|
+
return_value=mock_vector_engine,
|
|
587
|
+
),
|
|
588
|
+
patch(
|
|
589
|
+
"cognee.modules.retrieval.temporal_retriever.generate_completion",
|
|
590
|
+
return_value=TestModel(answer="Test answer"),
|
|
591
|
+
),
|
|
592
|
+
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
|
|
593
|
+
):
|
|
594
|
+
mock_config = MagicMock()
|
|
595
|
+
mock_config.caching = False
|
|
596
|
+
mock_cache_config.return_value = mock_config
|
|
597
|
+
|
|
598
|
+
completion = await retriever.get_completion(
|
|
599
|
+
"What happened in 2024?", response_model=TestModel
|
|
202
600
|
)
|
|
203
|
-
|
|
204
|
-
|
|
601
|
+
|
|
602
|
+
assert isinstance(completion, list)
|
|
603
|
+
assert len(completion) == 1
|
|
604
|
+
assert isinstance(completion[0], TestModel)
|
|
605
|
+
|
|
606
|
+
|
|
607
|
+
@pytest.mark.asyncio
|
|
608
|
+
async def test_extract_time_from_query_relative_path():
|
|
609
|
+
"""Test extract_time_from_query with relative prompt path."""
|
|
610
|
+
retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
|
|
611
|
+
|
|
612
|
+
mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
|
|
613
|
+
mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
|
|
614
|
+
mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
|
|
615
|
+
|
|
616
|
+
with (
|
|
617
|
+
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
|
|
618
|
+
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
|
|
619
|
+
patch(
|
|
620
|
+
"cognee.modules.retrieval.temporal_retriever.render_prompt",
|
|
621
|
+
return_value="System prompt",
|
|
622
|
+
),
|
|
623
|
+
patch.object(
|
|
624
|
+
LLMGateway,
|
|
625
|
+
"acreate_structured_output",
|
|
626
|
+
new_callable=AsyncMock,
|
|
627
|
+
return_value=mock_interval,
|
|
628
|
+
),
|
|
629
|
+
):
|
|
630
|
+
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
|
|
631
|
+
|
|
632
|
+
time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
|
|
633
|
+
|
|
634
|
+
assert time_from == mock_timestamp_from
|
|
635
|
+
assert time_to == mock_timestamp_to
|
|
205
636
|
|
|
206
637
|
|
|
207
|
-
# Test get_context fallback to triplets when no time is extracted
|
|
208
638
|
@pytest.mark.asyncio
|
|
209
|
-
async def
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
639
|
+
async def test_extract_time_from_query_absolute_path():
|
|
640
|
+
"""Test extract_time_from_query with absolute prompt path."""
|
|
641
|
+
retriever = TemporalRetriever(
|
|
642
|
+
time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt"
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
|
|
646
|
+
mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
|
|
647
|
+
mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
|
|
648
|
+
|
|
649
|
+
with (
|
|
650
|
+
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True),
|
|
651
|
+
patch(
|
|
652
|
+
"cognee.modules.retrieval.temporal_retriever.os.path.dirname",
|
|
653
|
+
return_value="/absolute/path/to",
|
|
654
|
+
),
|
|
655
|
+
patch(
|
|
656
|
+
"cognee.modules.retrieval.temporal_retriever.os.path.basename",
|
|
657
|
+
return_value="extract_query_time.txt",
|
|
658
|
+
),
|
|
659
|
+
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
|
|
660
|
+
patch(
|
|
661
|
+
"cognee.modules.retrieval.temporal_retriever.render_prompt",
|
|
662
|
+
return_value="System prompt",
|
|
663
|
+
),
|
|
664
|
+
patch.object(
|
|
665
|
+
LLMGateway,
|
|
666
|
+
"acreate_structured_output",
|
|
667
|
+
new_callable=AsyncMock,
|
|
668
|
+
return_value=mock_interval,
|
|
669
|
+
),
|
|
670
|
+
):
|
|
671
|
+
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
|
|
672
|
+
|
|
673
|
+
time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
|
|
674
|
+
|
|
675
|
+
assert time_from == mock_timestamp_from
|
|
676
|
+
assert time_to == mock_timestamp_to
|
|
215
677
|
|
|
216
678
|
|
|
217
|
-
# Test get_context when time is extracted and vector ranking is applied
|
|
218
679
|
@pytest.mark.asyncio
|
|
219
|
-
async def
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
680
|
+
async def test_extract_time_from_query_with_none_values():
|
|
681
|
+
"""Test extract_time_from_query when interval has None values."""
|
|
682
|
+
retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
|
|
683
|
+
|
|
684
|
+
mock_interval = QueryInterval(starts_at=None, ends_at=None)
|
|
685
|
+
|
|
686
|
+
with (
|
|
687
|
+
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
|
|
688
|
+
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
|
|
689
|
+
patch(
|
|
690
|
+
"cognee.modules.retrieval.temporal_retriever.render_prompt",
|
|
691
|
+
return_value="System prompt",
|
|
692
|
+
),
|
|
693
|
+
patch.object(
|
|
694
|
+
LLMGateway,
|
|
695
|
+
"acreate_structured_output",
|
|
696
|
+
new_callable=AsyncMock,
|
|
697
|
+
return_value=mock_interval,
|
|
698
|
+
),
|
|
699
|
+
):
|
|
700
|
+
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
|
|
701
|
+
|
|
702
|
+
time_from, time_to = await retriever.extract_time_from_query("What happened?")
|
|
703
|
+
|
|
704
|
+
assert time_from is None
|
|
705
|
+
assert time_to is None
|