cognee 0.5.1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. cognee/api/v1/add/add.py +2 -1
  2. cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
  3. cognee/api/v1/memify/routers/get_memify_router.py +1 -0
  4. cognee/api/v1/search/search.py +0 -4
  5. cognee/infrastructure/databases/relational/config.py +16 -1
  6. cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
  7. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
  8. cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
  9. cognee/infrastructure/llm/LLMGateway.py +0 -13
  10. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
  11. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
  12. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
  13. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
  14. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
  15. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
  16. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
  17. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
  18. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
  19. cognee/modules/data/models/Data.py +2 -1
  20. cognee/modules/retrieval/triplet_retriever.py +1 -1
  21. cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
  22. cognee/modules/search/methods/search.py +18 -25
  23. cognee/tasks/ingestion/data_item.py +8 -0
  24. cognee/tasks/ingestion/ingest_data.py +12 -1
  25. cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
  26. cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
  27. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
  28. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
  29. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
  30. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
  31. cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
  32. cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
  33. cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
  34. cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
  35. cognee/tests/test_custom_data_label.py +68 -0
  36. cognee/tests/test_search_db.py +334 -181
  37. cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
  38. cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
  39. cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
  40. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
  41. cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
  42. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
  43. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
  44. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
  45. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
  46. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
  47. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
  48. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
  49. cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
  50. cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
  51. cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
  52. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
  53. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
  54. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +58 -45
  55. cognee/tests/unit/modules/search/test_search.py +0 -100
  56. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
  57. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
  58. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
  59. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,7 +1,12 @@
1
1
  from types import SimpleNamespace
2
2
  import pytest
3
+ import os
4
+ from unittest.mock import AsyncMock, patch, MagicMock
5
+ from datetime import datetime
3
6
 
4
7
  from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
8
+ from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp
9
+ from cognee.infrastructure.llm import LLMGateway
5
10
 
6
11
 
7
12
  # Test TemporalRetriever initialization defaults and overrides
@@ -140,85 +145,561 @@ async def test_filter_top_k_events_error_handling():
140
145
  await tr.filter_top_k_events([{}], [])
141
146
 
142
147
 
143
- class _FakeRetriever(TemporalRetriever):
144
- def __init__(self, *args, **kwargs):
145
- super().__init__(*args, **kwargs)
146
- self._calls = []
148
+ @pytest.fixture
149
+ def mock_graph_engine():
150
+ """Create a mock graph engine."""
151
+ engine = AsyncMock()
152
+ engine.collect_time_ids = AsyncMock()
153
+ engine.collect_events = AsyncMock()
154
+ return engine
147
155
 
148
- async def extract_time_from_query(self, query: str):
149
- if "both" in query:
156
+
157
+ @pytest.fixture
158
+ def mock_vector_engine():
159
+ """Create a mock vector engine."""
160
+ engine = AsyncMock()
161
+ engine.embedding_engine = AsyncMock()
162
+ engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
163
+ engine.search = AsyncMock()
164
+ return engine
165
+
166
+
167
+ @pytest.mark.asyncio
168
+ async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine):
169
+ """Test get_context when time range is extracted from query."""
170
+ retriever = TemporalRetriever(top_k=5)
171
+
172
+ mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"]
173
+ mock_graph_engine.collect_events.return_value = [
174
+ {
175
+ "events": [
176
+ {"id": "e1", "description": "Event 1"},
177
+ {"id": "e2", "description": "Event 2"},
178
+ ]
179
+ }
180
+ ]
181
+
182
+ mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05)
183
+ mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10)
184
+ mock_vector_engine.search.return_value = [mock_result1, mock_result2]
185
+
186
+ with (
187
+ patch.object(
188
+ retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
189
+ ),
190
+ patch(
191
+ "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
192
+ return_value=mock_graph_engine,
193
+ ),
194
+ patch(
195
+ "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
196
+ return_value=mock_vector_engine,
197
+ ),
198
+ ):
199
+ context = await retriever.get_context("What happened in 2024?")
200
+
201
+ assert isinstance(context, str)
202
+ assert len(context) > 0
203
+ assert "Event" in context
204
+
205
+
206
+ @pytest.mark.asyncio
207
+ async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine):
208
+ """Test get_context falls back to triplets when no time is extracted."""
209
+ retriever = TemporalRetriever()
210
+
211
+ with (
212
+ patch(
213
+ "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
214
+ return_value=mock_graph_engine,
215
+ ),
216
+ patch.object(
217
+ retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
218
+ ) as mock_get_triplets,
219
+ patch.object(
220
+ retriever, "resolve_edges_to_text", return_value="triplet text"
221
+ ) as mock_resolve,
222
+ ):
223
+
224
+ async def mock_extract_time(query):
225
+ return None, None
226
+
227
+ retriever.extract_time_from_query = mock_extract_time
228
+
229
+ context = await retriever.get_context("test query")
230
+
231
+ assert context == "triplet text"
232
+ mock_get_triplets.assert_awaited_once_with("test query")
233
+ mock_resolve.assert_awaited_once()
234
+
235
+
236
+ @pytest.mark.asyncio
237
+ async def test_get_context_no_events_found(mock_graph_engine):
238
+ """Test get_context falls back to triplets when no events are found."""
239
+ retriever = TemporalRetriever()
240
+
241
+ mock_graph_engine.collect_time_ids.return_value = []
242
+
243
+ with (
244
+ patch(
245
+ "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
246
+ return_value=mock_graph_engine,
247
+ ),
248
+ patch.object(
249
+ retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
250
+ ) as mock_get_triplets,
251
+ patch.object(
252
+ retriever, "resolve_edges_to_text", return_value="triplet text"
253
+ ) as mock_resolve,
254
+ ):
255
+
256
+ async def mock_extract_time(query):
150
257
  return "2024-01-01", "2024-12-31"
151
- if "from_only" in query:
152
- return "2024-01-01", None
153
- if "to_only" in query:
154
- return None, "2024-12-31"
155
- return None, None
156
258
 
157
- async def get_triplets(self, query: str):
158
- self._calls.append(("get_triplets", query))
159
- return [{"s": "a", "p": "b", "o": "c"}]
259
+ retriever.extract_time_from_query = mock_extract_time
260
+
261
+ context = await retriever.get_context("test query")
262
+
263
+ assert context == "triplet text"
264
+ mock_get_triplets.assert_awaited_once_with("test query")
265
+ mock_resolve.assert_awaited_once()
266
+
267
+
268
+ @pytest.mark.asyncio
269
+ async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine):
270
+ """Test get_context with only time_from."""
271
+ retriever = TemporalRetriever(top_k=5)
272
+
273
+ mock_graph_engine.collect_time_ids.return_value = ["e1"]
274
+ mock_graph_engine.collect_events.return_value = [
275
+ {
276
+ "events": [
277
+ {"id": "e1", "description": "Event 1"},
278
+ ]
279
+ }
280
+ ]
281
+
282
+ mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
283
+ mock_vector_engine.search.return_value = [mock_result]
284
+
285
+ with (
286
+ patch.object(retriever, "extract_time_from_query", return_value=("2024-01-01", None)),
287
+ patch(
288
+ "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
289
+ return_value=mock_graph_engine,
290
+ ),
291
+ patch(
292
+ "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
293
+ return_value=mock_vector_engine,
294
+ ),
295
+ ):
296
+ context = await retriever.get_context("What happened after 2024?")
297
+
298
+ assert isinstance(context, str)
299
+ assert "Event 1" in context
300
+
301
+
302
+ @pytest.mark.asyncio
303
+ async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine):
304
+ """Test get_context with only time_to."""
305
+ retriever = TemporalRetriever(top_k=5)
306
+
307
+ mock_graph_engine.collect_time_ids.return_value = ["e1"]
308
+ mock_graph_engine.collect_events.return_value = [
309
+ {
310
+ "events": [
311
+ {"id": "e1", "description": "Event 1"},
312
+ ]
313
+ }
314
+ ]
315
+
316
+ mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
317
+ mock_vector_engine.search.return_value = [mock_result]
318
+
319
+ with (
320
+ patch.object(retriever, "extract_time_from_query", return_value=(None, "2024-12-31")),
321
+ patch(
322
+ "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
323
+ return_value=mock_graph_engine,
324
+ ),
325
+ patch(
326
+ "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
327
+ return_value=mock_vector_engine,
328
+ ),
329
+ ):
330
+ context = await retriever.get_context("What happened before 2024?")
331
+
332
+ assert isinstance(context, str)
333
+ assert "Event 1" in context
334
+
335
+
336
+ @pytest.mark.asyncio
337
+ async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine):
338
+ """Test get_completion retrieves context when not provided."""
339
+ retriever = TemporalRetriever()
340
+
341
+ mock_graph_engine.collect_time_ids.return_value = ["e1"]
342
+ mock_graph_engine.collect_events.return_value = [
343
+ {
344
+ "events": [
345
+ {"id": "e1", "description": "Event 1"},
346
+ ]
347
+ }
348
+ ]
349
+
350
+ mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
351
+ mock_vector_engine.search.return_value = [mock_result]
352
+
353
+ with (
354
+ patch.object(
355
+ retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
356
+ ),
357
+ patch(
358
+ "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
359
+ return_value=mock_graph_engine,
360
+ ),
361
+ patch(
362
+ "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
363
+ return_value=mock_vector_engine,
364
+ ),
365
+ patch(
366
+ "cognee.modules.retrieval.temporal_retriever.generate_completion",
367
+ return_value="Generated answer",
368
+ ),
369
+ patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
370
+ ):
371
+ mock_config = MagicMock()
372
+ mock_config.caching = False
373
+ mock_cache_config.return_value = mock_config
374
+
375
+ completion = await retriever.get_completion("What happened in 2024?")
376
+
377
+ assert isinstance(completion, list)
378
+ assert len(completion) == 1
379
+ assert completion[0] == "Generated answer"
380
+
381
+
382
+ @pytest.mark.asyncio
383
+ async def test_get_completion_with_provided_context():
384
+ """Test get_completion uses provided context."""
385
+ retriever = TemporalRetriever()
386
+
387
+ with (
388
+ patch(
389
+ "cognee.modules.retrieval.temporal_retriever.generate_completion",
390
+ return_value="Generated answer",
391
+ ),
392
+ patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
393
+ ):
394
+ mock_config = MagicMock()
395
+ mock_config.caching = False
396
+ mock_cache_config.return_value = mock_config
397
+
398
+ completion = await retriever.get_completion("test query", context="Provided context")
399
+
400
+ assert isinstance(completion, list)
401
+ assert len(completion) == 1
402
+ assert completion[0] == "Generated answer"
403
+
404
+
405
+ @pytest.mark.asyncio
406
+ async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine):
407
+ """Test get_completion with session caching enabled."""
408
+ retriever = TemporalRetriever()
409
+
410
+ mock_graph_engine.collect_time_ids.return_value = ["e1"]
411
+ mock_graph_engine.collect_events.return_value = [
412
+ {
413
+ "events": [
414
+ {"id": "e1", "description": "Event 1"},
415
+ ]
416
+ }
417
+ ]
418
+
419
+ mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
420
+ mock_vector_engine.search.return_value = [mock_result]
421
+
422
+ mock_user = MagicMock()
423
+ mock_user.id = "test-user-id"
424
+
425
+ with (
426
+ patch.object(
427
+ retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
428
+ ),
429
+ patch(
430
+ "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
431
+ return_value=mock_graph_engine,
432
+ ),
433
+ patch(
434
+ "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
435
+ return_value=mock_vector_engine,
436
+ ),
437
+ patch(
438
+ "cognee.modules.retrieval.temporal_retriever.get_conversation_history",
439
+ return_value="Previous conversation",
440
+ ),
441
+ patch(
442
+ "cognee.modules.retrieval.temporal_retriever.summarize_text",
443
+ return_value="Context summary",
444
+ ),
445
+ patch(
446
+ "cognee.modules.retrieval.temporal_retriever.generate_completion",
447
+ return_value="Generated answer",
448
+ ),
449
+ patch(
450
+ "cognee.modules.retrieval.temporal_retriever.save_conversation_history",
451
+ ) as mock_save,
452
+ patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
453
+ patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user,
454
+ ):
455
+ mock_config = MagicMock()
456
+ mock_config.caching = True
457
+ mock_cache_config.return_value = mock_config
458
+ mock_session_user.get.return_value = mock_user
459
+
460
+ completion = await retriever.get_completion(
461
+ "What happened in 2024?", session_id="test_session"
462
+ )
160
463
 
161
- async def resolve_edges_to_text(self, triplets):
162
- self._calls.append(("resolve_edges_to_text", len(triplets)))
163
- return "edges->text"
464
+ assert isinstance(completion, list)
465
+ assert len(completion) == 1
466
+ assert completion[0] == "Generated answer"
467
+ mock_save.assert_awaited_once()
164
468
 
165
- async def _fake_graph_collect_ids(self, **kwargs):
166
- return ["e1", "e2"]
167
469
 
168
- async def _fake_graph_collect_events(self, ids):
169
- return [
470
+ @pytest.mark.asyncio
471
+ async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine):
472
+ """Test get_completion with session config but no user ID."""
473
+ retriever = TemporalRetriever()
474
+
475
+ mock_graph_engine.collect_time_ids.return_value = ["e1"]
476
+ mock_graph_engine.collect_events.return_value = [
477
+ {
478
+ "events": [
479
+ {"id": "e1", "description": "Event 1"},
480
+ ]
481
+ }
482
+ ]
483
+
484
+ mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
485
+ mock_vector_engine.search.return_value = [mock_result]
486
+
487
+ with (
488
+ patch.object(
489
+ retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
490
+ ),
491
+ patch(
492
+ "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
493
+ return_value=mock_graph_engine,
494
+ ),
495
+ patch(
496
+ "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
497
+ return_value=mock_vector_engine,
498
+ ),
499
+ patch(
500
+ "cognee.modules.retrieval.temporal_retriever.generate_completion",
501
+ return_value="Generated answer",
502
+ ),
503
+ patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
504
+ patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user,
505
+ ):
506
+ mock_config = MagicMock()
507
+ mock_config.caching = True
508
+ mock_cache_config.return_value = mock_config
509
+ mock_session_user.get.return_value = None # No user
510
+
511
+ completion = await retriever.get_completion("What happened in 2024?")
512
+
513
+ assert isinstance(completion, list)
514
+ assert len(completion) == 1
515
+
516
+
517
+ @pytest.mark.asyncio
518
+ async def test_get_completion_context_retrieved_but_empty(mock_graph_engine):
519
+ """Test get_completion when get_context returns empty string."""
520
+ retriever = TemporalRetriever()
521
+
522
+ with (
523
+ patch.object(
524
+ retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
525
+ ),
526
+ patch(
527
+ "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
528
+ return_value=mock_graph_engine,
529
+ ),
530
+ patch(
531
+ "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
532
+ ) as mock_get_vector,
533
+ patch.object(retriever, "filter_top_k_events", return_value=[]),
534
+ ):
535
+ mock_vector_engine = AsyncMock()
536
+ mock_vector_engine.embedding_engine = AsyncMock()
537
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
538
+ mock_vector_engine.search = AsyncMock(return_value=[])
539
+ mock_get_vector.return_value = mock_vector_engine
540
+
541
+ mock_graph_engine.collect_time_ids.return_value = ["e1"]
542
+ mock_graph_engine.collect_events.return_value = [
170
543
  {
171
544
  "events": [
172
- {"id": "e1", "description": "E1"},
173
- {"id": "e2", "description": "E2"},
174
- {"id": "e3", "description": "E3"},
545
+ {"id": "e1", "description": ""},
175
546
  ]
176
547
  }
177
548
  ]
178
549
 
179
- async def _fake_vector_embed(self, texts):
180
- assert isinstance(texts, list) and texts
181
- return [[0.0, 1.0, 2.0]]
550
+ with pytest.raises((UnboundLocalError, NameError)):
551
+ await retriever.get_completion("test query")
182
552
 
183
- async def _fake_vector_search(self, **kwargs):
184
- return [
185
- SimpleNamespace(payload={"id": "e2"}, score=0.05),
186
- SimpleNamespace(payload={"id": "e1"}, score=0.10),
187
- ]
188
553
 
189
- async def get_context(self, query: str):
190
- time_from, time_to = await self.extract_time_from_query(query)
554
+ @pytest.mark.asyncio
555
+ async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine):
556
+ """Test get_completion with custom response model."""
557
+ from pydantic import BaseModel
191
558
 
192
- if not (time_from or time_to):
193
- triplets = await self.get_triplets(query)
194
- return await self.resolve_edges_to_text(triplets)
559
+ class TestModel(BaseModel):
560
+ answer: str
195
561
 
196
- ids = await self._fake_graph_collect_ids(time_from=time_from, time_to=time_to)
197
- relevant_events = await self._fake_graph_collect_events(ids)
562
+ retriever = TemporalRetriever()
198
563
 
199
- _ = await self._fake_vector_embed([query])
200
- vector_search_results = await self._fake_vector_search(
201
- collection_name="Event_name", query_vector=[0.0], limit=0
564
+ mock_graph_engine.collect_time_ids.return_value = ["e1"]
565
+ mock_graph_engine.collect_events.return_value = [
566
+ {
567
+ "events": [
568
+ {"id": "e1", "description": "Event 1"},
569
+ ]
570
+ }
571
+ ]
572
+
573
+ mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
574
+ mock_vector_engine.search.return_value = [mock_result]
575
+
576
+ with (
577
+ patch.object(
578
+ retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
579
+ ),
580
+ patch(
581
+ "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
582
+ return_value=mock_graph_engine,
583
+ ),
584
+ patch(
585
+ "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
586
+ return_value=mock_vector_engine,
587
+ ),
588
+ patch(
589
+ "cognee.modules.retrieval.temporal_retriever.generate_completion",
590
+ return_value=TestModel(answer="Test answer"),
591
+ ),
592
+ patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
593
+ ):
594
+ mock_config = MagicMock()
595
+ mock_config.caching = False
596
+ mock_cache_config.return_value = mock_config
597
+
598
+ completion = await retriever.get_completion(
599
+ "What happened in 2024?", response_model=TestModel
202
600
  )
203
- top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
204
- return self.descriptions_to_string(top_k_events)
601
+
602
+ assert isinstance(completion, list)
603
+ assert len(completion) == 1
604
+ assert isinstance(completion[0], TestModel)
605
+
606
+
607
+ @pytest.mark.asyncio
608
+ async def test_extract_time_from_query_relative_path():
609
+ """Test extract_time_from_query with relative prompt path."""
610
+ retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
611
+
612
+ mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
613
+ mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
614
+ mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
615
+
616
+ with (
617
+ patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
618
+ patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
619
+ patch(
620
+ "cognee.modules.retrieval.temporal_retriever.render_prompt",
621
+ return_value="System prompt",
622
+ ),
623
+ patch.object(
624
+ LLMGateway,
625
+ "acreate_structured_output",
626
+ new_callable=AsyncMock,
627
+ return_value=mock_interval,
628
+ ),
629
+ ):
630
+ mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
631
+
632
+ time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
633
+
634
+ assert time_from == mock_timestamp_from
635
+ assert time_to == mock_timestamp_to
205
636
 
206
637
 
207
- # Test get_context fallback to triplets when no time is extracted
208
638
  @pytest.mark.asyncio
209
- async def test_fake_get_context_falls_back_to_triplets_when_no_time():
210
- tr = _FakeRetriever(top_k=2)
211
- ctx = await tr.get_context("no_time")
212
- assert ctx == "edges->text"
213
- assert tr._calls[0][0] == "get_triplets"
214
- assert tr._calls[1][0] == "resolve_edges_to_text"
639
+ async def test_extract_time_from_query_absolute_path():
640
+ """Test extract_time_from_query with absolute prompt path."""
641
+ retriever = TemporalRetriever(
642
+ time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt"
643
+ )
644
+
645
+ mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
646
+ mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
647
+ mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
648
+
649
+ with (
650
+ patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True),
651
+ patch(
652
+ "cognee.modules.retrieval.temporal_retriever.os.path.dirname",
653
+ return_value="/absolute/path/to",
654
+ ),
655
+ patch(
656
+ "cognee.modules.retrieval.temporal_retriever.os.path.basename",
657
+ return_value="extract_query_time.txt",
658
+ ),
659
+ patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
660
+ patch(
661
+ "cognee.modules.retrieval.temporal_retriever.render_prompt",
662
+ return_value="System prompt",
663
+ ),
664
+ patch.object(
665
+ LLMGateway,
666
+ "acreate_structured_output",
667
+ new_callable=AsyncMock,
668
+ return_value=mock_interval,
669
+ ),
670
+ ):
671
+ mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
672
+
673
+ time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
674
+
675
+ assert time_from == mock_timestamp_from
676
+ assert time_to == mock_timestamp_to
215
677
 
216
678
 
217
- # Test get_context when time is extracted and vector ranking is applied
218
679
  @pytest.mark.asyncio
219
- async def test_fake_get_context_with_time_filters_and_vector_ranking():
220
- tr = _FakeRetriever(top_k=2)
221
- ctx = await tr.get_context("both time")
222
- assert ctx.startswith("E2")
223
- assert "#####################" in ctx
224
- assert "E1" in ctx and "E3" not in ctx
680
+ async def test_extract_time_from_query_with_none_values():
681
+ """Test extract_time_from_query when interval has None values."""
682
+ retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
683
+
684
+ mock_interval = QueryInterval(starts_at=None, ends_at=None)
685
+
686
+ with (
687
+ patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
688
+ patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
689
+ patch(
690
+ "cognee.modules.retrieval.temporal_retriever.render_prompt",
691
+ return_value="System prompt",
692
+ ),
693
+ patch.object(
694
+ LLMGateway,
695
+ "acreate_structured_output",
696
+ new_callable=AsyncMock,
697
+ return_value=mock_interval,
698
+ ),
699
+ ):
700
+ mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
701
+
702
+ time_from, time_to = await retriever.extract_time_from_query("What happened?")
703
+
704
+ assert time_from is None
705
+ assert time_to is None