cognee 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161) hide show
  1. cognee/__init__.py +1 -0
  2. cognee/api/health.py +2 -12
  3. cognee/api/v1/add/add.py +46 -6
  4. cognee/api/v1/add/routers/get_add_router.py +5 -1
  5. cognee/api/v1/cognify/cognify.py +29 -9
  6. cognee/api/v1/datasets/datasets.py +11 -0
  7. cognee/api/v1/responses/default_tools.py +0 -1
  8. cognee/api/v1/responses/dispatch_function.py +1 -1
  9. cognee/api/v1/responses/routers/default_tools.py +0 -1
  10. cognee/api/v1/search/search.py +11 -9
  11. cognee/api/v1/settings/routers/get_settings_router.py +7 -1
  12. cognee/api/v1/ui/ui.py +47 -16
  13. cognee/api/v1/update/routers/get_update_router.py +1 -1
  14. cognee/api/v1/update/update.py +3 -3
  15. cognee/cli/_cognee.py +61 -10
  16. cognee/cli/commands/add_command.py +3 -3
  17. cognee/cli/commands/cognify_command.py +3 -3
  18. cognee/cli/commands/config_command.py +9 -7
  19. cognee/cli/commands/delete_command.py +3 -3
  20. cognee/cli/commands/search_command.py +3 -7
  21. cognee/cli/config.py +0 -1
  22. cognee/context_global_variables.py +5 -0
  23. cognee/exceptions/exceptions.py +1 -1
  24. cognee/infrastructure/databases/cache/__init__.py +2 -0
  25. cognee/infrastructure/databases/cache/cache_db_interface.py +79 -0
  26. cognee/infrastructure/databases/cache/config.py +44 -0
  27. cognee/infrastructure/databases/cache/get_cache_engine.py +67 -0
  28. cognee/infrastructure/databases/cache/redis/RedisAdapter.py +243 -0
  29. cognee/infrastructure/databases/exceptions/__init__.py +1 -0
  30. cognee/infrastructure/databases/exceptions/exceptions.py +18 -2
  31. cognee/infrastructure/databases/graph/get_graph_engine.py +1 -1
  32. cognee/infrastructure/databases/graph/graph_db_interface.py +5 -0
  33. cognee/infrastructure/databases/graph/kuzu/adapter.py +67 -44
  34. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +13 -3
  35. cognee/infrastructure/databases/graph/neo4j_driver/deadlock_retry.py +1 -1
  36. cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py +1 -1
  37. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +1 -1
  38. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +21 -3
  39. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +17 -10
  40. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +17 -4
  41. cognee/infrastructure/databases/vector/embeddings/config.py +2 -3
  42. cognee/infrastructure/databases/vector/exceptions/exceptions.py +1 -1
  43. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +0 -1
  44. cognee/infrastructure/files/exceptions.py +1 -1
  45. cognee/infrastructure/files/storage/LocalFileStorage.py +9 -9
  46. cognee/infrastructure/files/storage/S3FileStorage.py +11 -11
  47. cognee/infrastructure/files/utils/guess_file_type.py +6 -0
  48. cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +0 -5
  49. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +19 -9
  50. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +17 -5
  51. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +17 -5
  52. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +32 -0
  53. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/__init__.py +0 -0
  54. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +109 -0
  55. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +33 -8
  56. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +40 -18
  57. cognee/infrastructure/loaders/LoaderEngine.py +27 -7
  58. cognee/infrastructure/loaders/external/__init__.py +7 -0
  59. cognee/infrastructure/loaders/external/advanced_pdf_loader.py +2 -8
  60. cognee/infrastructure/loaders/external/beautiful_soup_loader.py +310 -0
  61. cognee/infrastructure/loaders/supported_loaders.py +7 -0
  62. cognee/modules/data/exceptions/exceptions.py +1 -1
  63. cognee/modules/data/methods/__init__.py +3 -0
  64. cognee/modules/data/methods/get_dataset_data.py +4 -1
  65. cognee/modules/data/methods/has_dataset_data.py +21 -0
  66. cognee/modules/engine/models/TableRow.py +0 -1
  67. cognee/modules/ingestion/save_data_to_file.py +9 -2
  68. cognee/modules/pipelines/exceptions/exceptions.py +1 -1
  69. cognee/modules/pipelines/operations/pipeline.py +12 -1
  70. cognee/modules/pipelines/operations/run_tasks.py +25 -197
  71. cognee/modules/pipelines/operations/run_tasks_data_item.py +260 -0
  72. cognee/modules/pipelines/operations/run_tasks_distributed.py +121 -38
  73. cognee/modules/retrieval/EntityCompletionRetriever.py +48 -8
  74. cognee/modules/retrieval/base_graph_retriever.py +3 -1
  75. cognee/modules/retrieval/base_retriever.py +3 -1
  76. cognee/modules/retrieval/chunks_retriever.py +5 -1
  77. cognee/modules/retrieval/code_retriever.py +20 -2
  78. cognee/modules/retrieval/completion_retriever.py +50 -9
  79. cognee/modules/retrieval/cypher_search_retriever.py +11 -1
  80. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +47 -8
  81. cognee/modules/retrieval/graph_completion_cot_retriever.py +32 -1
  82. cognee/modules/retrieval/graph_completion_retriever.py +54 -10
  83. cognee/modules/retrieval/lexical_retriever.py +20 -2
  84. cognee/modules/retrieval/natural_language_retriever.py +10 -1
  85. cognee/modules/retrieval/summaries_retriever.py +5 -1
  86. cognee/modules/retrieval/temporal_retriever.py +62 -10
  87. cognee/modules/retrieval/user_qa_feedback.py +3 -2
  88. cognee/modules/retrieval/utils/completion.py +5 -0
  89. cognee/modules/retrieval/utils/description_to_codepart_search.py +1 -1
  90. cognee/modules/retrieval/utils/session_cache.py +156 -0
  91. cognee/modules/search/methods/get_search_type_tools.py +0 -5
  92. cognee/modules/search/methods/no_access_control_search.py +12 -1
  93. cognee/modules/search/methods/search.py +34 -2
  94. cognee/modules/search/types/SearchType.py +0 -1
  95. cognee/modules/settings/get_settings.py +23 -0
  96. cognee/modules/users/methods/get_authenticated_user.py +3 -1
  97. cognee/modules/users/methods/get_default_user.py +1 -6
  98. cognee/modules/users/roles/methods/create_role.py +2 -2
  99. cognee/modules/users/tenants/methods/create_tenant.py +2 -2
  100. cognee/shared/exceptions/exceptions.py +1 -1
  101. cognee/tasks/codingagents/coding_rule_associations.py +1 -2
  102. cognee/tasks/documents/exceptions/exceptions.py +1 -1
  103. cognee/tasks/graph/extract_graph_from_data.py +2 -0
  104. cognee/tasks/ingestion/data_item_to_text_file.py +3 -3
  105. cognee/tasks/ingestion/ingest_data.py +11 -5
  106. cognee/tasks/ingestion/save_data_item_to_storage.py +12 -1
  107. cognee/tasks/storage/add_data_points.py +3 -10
  108. cognee/tasks/storage/index_data_points.py +19 -14
  109. cognee/tasks/storage/index_graph_edges.py +25 -11
  110. cognee/tasks/web_scraper/__init__.py +34 -0
  111. cognee/tasks/web_scraper/config.py +26 -0
  112. cognee/tasks/web_scraper/default_url_crawler.py +446 -0
  113. cognee/tasks/web_scraper/models.py +46 -0
  114. cognee/tasks/web_scraper/types.py +4 -0
  115. cognee/tasks/web_scraper/utils.py +142 -0
  116. cognee/tasks/web_scraper/web_scraper_task.py +396 -0
  117. cognee/tests/cli_tests/cli_unit_tests/test_cli_utils.py +0 -1
  118. cognee/tests/integration/web_url_crawler/test_default_url_crawler.py +13 -0
  119. cognee/tests/integration/web_url_crawler/test_tavily_crawler.py +19 -0
  120. cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py +344 -0
  121. cognee/tests/subprocesses/reader.py +25 -0
  122. cognee/tests/subprocesses/simple_cognify_1.py +31 -0
  123. cognee/tests/subprocesses/simple_cognify_2.py +31 -0
  124. cognee/tests/subprocesses/writer.py +32 -0
  125. cognee/tests/tasks/descriptive_metrics/metrics_test_utils.py +0 -2
  126. cognee/tests/tasks/descriptive_metrics/neo4j_metrics_test.py +8 -3
  127. cognee/tests/tasks/entity_extraction/entity_extraction_test.py +89 -0
  128. cognee/tests/tasks/web_scraping/web_scraping_test.py +172 -0
  129. cognee/tests/test_add_docling_document.py +56 -0
  130. cognee/tests/test_chromadb.py +7 -11
  131. cognee/tests/test_concurrent_subprocess_access.py +76 -0
  132. cognee/tests/test_conversation_history.py +240 -0
  133. cognee/tests/test_kuzu.py +27 -15
  134. cognee/tests/test_lancedb.py +7 -11
  135. cognee/tests/test_library.py +32 -2
  136. cognee/tests/test_neo4j.py +24 -16
  137. cognee/tests/test_neptune_analytics_vector.py +7 -11
  138. cognee/tests/test_permissions.py +9 -13
  139. cognee/tests/test_pgvector.py +4 -4
  140. cognee/tests/test_remote_kuzu.py +8 -11
  141. cognee/tests/test_s3_file_storage.py +1 -1
  142. cognee/tests/test_search_db.py +6 -8
  143. cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +89 -0
  144. cognee/tests/unit/modules/retrieval/conversation_history_test.py +154 -0
  145. {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/METADATA +22 -7
  146. {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/RECORD +155 -128
  147. {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/entry_points.txt +1 -0
  148. distributed/Dockerfile +0 -3
  149. distributed/entrypoint.py +21 -9
  150. distributed/signal.py +5 -0
  151. distributed/workers/data_point_saving_worker.py +64 -34
  152. distributed/workers/graph_saving_worker.py +71 -47
  153. cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py +0 -1116
  154. cognee/modules/retrieval/insights_retriever.py +0 -133
  155. cognee/tests/test_memgraph.py +0 -109
  156. cognee/tests/unit/modules/retrieval/insights_retriever_test.py +0 -251
  157. distributed/poetry.lock +0 -12238
  158. distributed/pyproject.toml +0 -185
  159. {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/WHEEL +0 -0
  160. {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/licenses/LICENSE +0 -0
  161. {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/licenses/NOTICE.md +0 -0
@@ -3,49 +3,97 @@ try:
3
3
  except ModuleNotFoundError:
4
4
  modal = None
5
5
 
6
+ from typing import Any, List, Optional
7
+ from uuid import UUID
8
+
9
+ from cognee.modules.pipelines.tasks.task import Task
6
10
  from cognee.infrastructure.databases.relational import get_relational_engine
11
+ from cognee.infrastructure.databases.graph import get_graph_engine
7
12
  from cognee.modules.pipelines.models import (
8
13
  PipelineRunStarted,
9
- PipelineRunYield,
10
14
  PipelineRunCompleted,
15
+ PipelineRunErrored,
16
+ )
17
+ from cognee.modules.pipelines.operations import (
18
+ log_pipeline_run_start,
19
+ log_pipeline_run_complete,
20
+ log_pipeline_run_error,
11
21
  )
12
- from cognee.modules.pipelines.operations import log_pipeline_run_start, log_pipeline_run_complete
13
- from cognee.modules.pipelines.utils.generate_pipeline_id import generate_pipeline_id
22
+ from cognee.modules.pipelines.utils import generate_pipeline_id
14
23
  from cognee.modules.users.methods import get_default_user
15
24
  from cognee.shared.logging_utils import get_logger
16
-
17
- from .run_tasks_with_telemetry import run_tasks_with_telemetry
18
-
25
+ from cognee.modules.users.models import User
26
+ from cognee.modules.pipelines.exceptions import PipelineRunFailedError
27
+ from cognee.tasks.ingestion import resolve_data_directories
28
+ from .run_tasks_data_item import run_tasks_data_item
19
29
 
20
30
  logger = get_logger("run_tasks_distributed()")
21
31
 
22
-
23
32
  if modal:
33
+ import os
24
34
  from distributed.app import app
25
35
  from distributed.modal_image import image
26
36
 
37
+ secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")
38
+
27
39
  @app.function(
28
40
  retries=3,
29
41
  image=image,
30
42
  timeout=86400,
31
43
  max_containers=50,
32
- secrets=[modal.Secret.from_name("distributed_cognee")],
44
+ secrets=[modal.Secret.from_name(secret_name)],
33
45
  )
34
- async def run_tasks_on_modal(tasks, data_item, user, pipeline_name, context):
35
- pipeline_run = run_tasks_with_telemetry(tasks, data_item, user, pipeline_name, context)
36
-
37
- run_info = None
38
-
39
- async for pipeline_run_info in pipeline_run:
40
- run_info = pipeline_run_info
46
+ async def run_tasks_on_modal(
47
+ data_item,
48
+ dataset_id: UUID,
49
+ tasks: List[Task],
50
+ pipeline_name: str,
51
+ pipeline_id: str,
52
+ pipeline_run_id: str,
53
+ context: Optional[dict],
54
+ user: User,
55
+ incremental_loading: bool,
56
+ ):
57
+ """
58
+ Wrapper that runs the run_tasks_data_item function.
59
+ This is the function/code that runs on modal executor and produces the graph/vector db objects
60
+ """
61
+ from cognee.infrastructure.databases.relational import get_relational_engine
62
+
63
+ async with get_relational_engine().get_async_session() as session:
64
+ from cognee.modules.data.models import Dataset
65
+
66
+ dataset = await session.get(Dataset, dataset_id)
67
+
68
+ result = await run_tasks_data_item(
69
+ data_item=data_item,
70
+ dataset=dataset,
71
+ tasks=tasks,
72
+ pipeline_name=pipeline_name,
73
+ pipeline_id=pipeline_id,
74
+ pipeline_run_id=pipeline_run_id,
75
+ context=context,
76
+ user=user,
77
+ incremental_loading=incremental_loading,
78
+ )
41
79
 
42
- return run_info
80
+ return result
43
81
 
44
82
 
45
- async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, context):
83
+ async def run_tasks_distributed(
84
+ tasks: List[Task],
85
+ dataset_id: UUID,
86
+ data: List[Any] = None,
87
+ user: User = None,
88
+ pipeline_name: str = "unknown_pipeline",
89
+ context: dict = None,
90
+ incremental_loading: bool = False,
91
+ data_per_batch: int = 20,
92
+ ):
46
93
  if not user:
47
94
  user = await get_default_user()
48
95
 
96
+ # Get dataset object
49
97
  db_engine = get_relational_engine()
50
98
  async with db_engine.get_async_session() as session:
51
99
  from cognee.modules.data.models import Dataset
@@ -53,9 +101,7 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
53
101
  dataset = await session.get(Dataset, dataset_id)
54
102
 
55
103
  pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
56
-
57
104
  pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
58
-
59
105
  pipeline_run_id = pipeline_run.pipeline_run_id
60
106
 
61
107
  yield PipelineRunStarted(
@@ -65,30 +111,67 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
65
111
  payload=data,
66
112
  )
67
113
 
68
- data_count = len(data) if isinstance(data, list) else 1
114
+ try:
115
+ if not isinstance(data, list):
116
+ data = [data]
117
+
118
+ data = await resolve_data_directories(data)
119
+
120
+ number_of_data_items = len(data) if isinstance(data, list) else 1
121
+
122
+ data_item_tasks = [
123
+ data,
124
+ [dataset.id] * number_of_data_items,
125
+ [tasks] * number_of_data_items,
126
+ [pipeline_name] * number_of_data_items,
127
+ [pipeline_id] * number_of_data_items,
128
+ [pipeline_run_id] * number_of_data_items,
129
+ [context] * number_of_data_items,
130
+ [user] * number_of_data_items,
131
+ [incremental_loading] * number_of_data_items,
132
+ ]
133
+
134
+ results = []
135
+ async for result in run_tasks_on_modal.map.aio(*data_item_tasks):
136
+ if not result:
137
+ continue
138
+ results.append(result)
139
+
140
+ # Remove skipped results
141
+ results = [r for r in results if r]
142
+
143
+ # If any data item failed, raise PipelineRunFailedError
144
+ errored = [
145
+ r
146
+ for r in results
147
+ if r and r.get("run_info") and isinstance(r["run_info"], PipelineRunErrored)
148
+ ]
149
+ if errored:
150
+ raise PipelineRunFailedError("Pipeline run failed. Data item could not be processed.")
151
+
152
+ await log_pipeline_run_complete(
153
+ pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
154
+ )
69
155
 
70
- arguments = [
71
- [tasks] * data_count,
72
- [[data_item] for data_item in data[:data_count]] if data_count > 1 else [data],
73
- [user] * data_count,
74
- [pipeline_name] * data_count,
75
- [context] * data_count,
76
- ]
156
+ yield PipelineRunCompleted(
157
+ pipeline_run_id=pipeline_run_id,
158
+ dataset_id=dataset.id,
159
+ dataset_name=dataset.name,
160
+ data_ingestion_info=results,
161
+ )
77
162
 
78
- async for result in run_tasks_on_modal.map.aio(*arguments):
79
- logger.info(f"Received result: {result}")
163
+ except Exception as error:
164
+ await log_pipeline_run_error(
165
+ pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
166
+ )
80
167
 
81
- yield PipelineRunYield(
168
+ yield PipelineRunErrored(
82
169
  pipeline_run_id=pipeline_run_id,
170
+ payload=repr(error),
83
171
  dataset_id=dataset.id,
84
172
  dataset_name=dataset.name,
85
- payload=result,
173
+ data_ingestion_info=locals().get("results"),
86
174
  )
87
175
 
88
- await log_pipeline_run_complete(pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data)
89
-
90
- yield PipelineRunCompleted(
91
- pipeline_run_id=pipeline_run_id,
92
- dataset_id=dataset.id,
93
- dataset_name=dataset.name,
94
- )
176
+ if not isinstance(error, PipelineRunFailedError):
177
+ raise
@@ -1,10 +1,17 @@
1
+ import asyncio
1
2
  from typing import Any, Optional, List
2
3
  from cognee.shared.logging_utils import get_logger
3
4
 
4
5
  from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
5
6
  from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider
6
7
  from cognee.modules.retrieval.base_retriever import BaseRetriever
7
- from cognee.modules.retrieval.utils.completion import generate_completion
8
+ from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
9
+ from cognee.modules.retrieval.utils.session_cache import (
10
+ save_conversation_history,
11
+ get_conversation_history,
12
+ )
13
+ from cognee.context_global_variables import session_user
14
+ from cognee.infrastructure.databases.cache.config import CacheConfig
8
15
 
9
16
 
10
17
  logger = get_logger("entity_completion_retriever")
@@ -77,7 +84,9 @@ class EntityCompletionRetriever(BaseRetriever):
77
84
  logger.error(f"Context retrieval failed: {str(e)}")
78
85
  return None
79
86
 
80
- async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
87
+ async def get_completion(
88
+ self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
89
+ ) -> List[str]:
81
90
  """
82
91
  Generate completion using provided context or fetch new context.
83
92
 
@@ -91,6 +100,8 @@ class EntityCompletionRetriever(BaseRetriever):
91
100
  - query (str): The query string for which completion is being generated.
92
101
  - context (Optional[Any]): Optional context to be used for generating completion;
93
102
  fetched if not provided. (default None)
103
+ - session_id (Optional[str]): Optional session identifier for caching. If None,
104
+ defaults to 'default_session'. (default None)
94
105
 
95
106
  Returns:
96
107
  --------
@@ -105,12 +116,41 @@ class EntityCompletionRetriever(BaseRetriever):
105
116
  if context is None:
106
117
  return ["No relevant entities found for the query."]
107
118
 
108
- completion = await generate_completion(
109
- query=query,
110
- context=context,
111
- user_prompt_path=self.user_prompt_path,
112
- system_prompt_path=self.system_prompt_path,
113
- )
119
+ # Check if we need to generate context summary for caching
120
+ cache_config = CacheConfig()
121
+ user = session_user.get()
122
+ user_id = getattr(user, "id", None)
123
+ session_save = user_id and cache_config.caching
124
+
125
+ if session_save:
126
+ conversation_history = await get_conversation_history(session_id=session_id)
127
+
128
+ context_summary, completion = await asyncio.gather(
129
+ summarize_text(str(context)),
130
+ generate_completion(
131
+ query=query,
132
+ context=context,
133
+ user_prompt_path=self.user_prompt_path,
134
+ system_prompt_path=self.system_prompt_path,
135
+ conversation_history=conversation_history,
136
+ ),
137
+ )
138
+ else:
139
+ completion = await generate_completion(
140
+ query=query,
141
+ context=context,
142
+ user_prompt_path=self.user_prompt_path,
143
+ system_prompt_path=self.system_prompt_path,
144
+ )
145
+
146
+ if session_save:
147
+ await save_conversation_history(
148
+ query=query,
149
+ context_summary=context_summary,
150
+ answer=completion,
151
+ session_id=session_id,
152
+ )
153
+
114
154
  return [completion]
115
155
 
116
156
  except Exception as e:
@@ -13,6 +13,8 @@ class BaseGraphRetriever(ABC):
13
13
  pass
14
14
 
15
15
  @abstractmethod
16
- async def get_completion(self, query: str, context: Optional[List[Edge]] = None) -> str:
16
+ async def get_completion(
17
+ self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None
18
+ ) -> str:
17
19
  """Generates a response using the query and optional context (triplets)."""
18
20
  pass
@@ -11,6 +11,8 @@ class BaseRetriever(ABC):
11
11
  pass
12
12
 
13
13
  @abstractmethod
14
- async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
14
+ async def get_completion(
15
+ self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
16
+ ) -> Any:
15
17
  """Generates a response using the query and optional context."""
16
18
  pass
@@ -61,7 +61,9 @@ class ChunksRetriever(BaseRetriever):
61
61
  logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
62
62
  return chunk_payloads
63
63
 
64
- async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
64
+ async def get_completion(
65
+ self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
66
+ ) -> Any:
65
67
  """
66
68
  Generates a completion using document chunks context.
67
69
 
@@ -74,6 +76,8 @@ class ChunksRetriever(BaseRetriever):
74
76
  - query (str): The query string to be used for generating a completion.
75
77
  - context (Optional[Any]): Optional pre-fetched context to use for generating the
76
78
  completion; if None, it retrieves the context for the query. (default None)
79
+ - session_id (Optional[str]): Optional session identifier for caching. If None,
80
+ defaults to 'default_session'. (default None)
77
81
 
78
82
  Returns:
79
83
  --------
@@ -207,8 +207,26 @@ class CodeRetriever(BaseRetriever):
207
207
  logger.info(f"Returning {len(result)} code file contexts")
208
208
  return result
209
209
 
210
- async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
211
- """Returns the code files context."""
210
+ async def get_completion(
211
+ self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
212
+ ) -> Any:
213
+ """
214
+ Returns the code files context.
215
+
216
+ Parameters:
217
+ -----------
218
+
219
+ - query (str): The query string to retrieve code context for.
220
+ - context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
221
+ the context for the query. (default None)
222
+ - session_id (Optional[str]): Optional session identifier for caching. If None,
223
+ defaults to 'default_session'. (default None)
224
+
225
+ Returns:
226
+ --------
227
+
228
+ - Any: The code files context, either provided or retrieved.
229
+ """
212
230
  if context is None:
213
231
  context = await self.get_context(query)
214
232
  return context
@@ -1,11 +1,18 @@
1
+ import asyncio
1
2
  from typing import Any, Optional
2
3
 
3
4
  from cognee.shared.logging_utils import get_logger
4
5
  from cognee.infrastructure.databases.vector import get_vector_engine
5
- from cognee.modules.retrieval.utils.completion import generate_completion
6
+ from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
7
+ from cognee.modules.retrieval.utils.session_cache import (
8
+ save_conversation_history,
9
+ get_conversation_history,
10
+ )
6
11
  from cognee.modules.retrieval.base_retriever import BaseRetriever
7
12
  from cognee.modules.retrieval.exceptions.exceptions import NoDataError
8
13
  from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
14
+ from cognee.context_global_variables import session_user
15
+ from cognee.infrastructure.databases.cache.config import CacheConfig
9
16
 
10
17
  logger = get_logger("CompletionRetriever")
11
18
 
@@ -67,7 +74,9 @@ class CompletionRetriever(BaseRetriever):
67
74
  logger.error("DocumentChunk_text collection not found")
68
75
  raise NoDataError("No data found in the system, please add data first.") from error
69
76
 
70
- async def get_completion(self, query: str, context: Optional[Any] = None) -> str:
77
+ async def get_completion(
78
+ self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
79
+ ) -> str:
71
80
  """
72
81
  Generates an LLM completion using the context.
73
82
 
@@ -80,6 +89,8 @@ class CompletionRetriever(BaseRetriever):
80
89
  - query (str): The query string to be used for generating a completion.
81
90
  - context (Optional[Any]): Optional pre-fetched context to use for generating the
82
91
  completion; if None, it retrieves the context for the query. (default None)
92
+ - session_id (Optional[str]): Optional session identifier for caching. If None,
93
+ defaults to 'default_session'. (default None)
83
94
 
84
95
  Returns:
85
96
  --------
@@ -89,11 +100,41 @@ class CompletionRetriever(BaseRetriever):
89
100
  if context is None:
90
101
  context = await self.get_context(query)
91
102
 
92
- completion = await generate_completion(
93
- query=query,
94
- context=context,
95
- user_prompt_path=self.user_prompt_path,
96
- system_prompt_path=self.system_prompt_path,
97
- system_prompt=self.system_prompt,
98
- )
103
+ # Check if we need to generate context summary for caching
104
+ cache_config = CacheConfig()
105
+ user = session_user.get()
106
+ user_id = getattr(user, "id", None)
107
+ session_save = user_id and cache_config.caching
108
+
109
+ if session_save:
110
+ conversation_history = await get_conversation_history(session_id=session_id)
111
+
112
+ context_summary, completion = await asyncio.gather(
113
+ summarize_text(context),
114
+ generate_completion(
115
+ query=query,
116
+ context=context,
117
+ user_prompt_path=self.user_prompt_path,
118
+ system_prompt_path=self.system_prompt_path,
119
+ system_prompt=self.system_prompt,
120
+ conversation_history=conversation_history,
121
+ ),
122
+ )
123
+ else:
124
+ completion = await generate_completion(
125
+ query=query,
126
+ context=context,
127
+ user_prompt_path=self.user_prompt_path,
128
+ system_prompt_path=self.system_prompt_path,
129
+ system_prompt=self.system_prompt,
130
+ )
131
+
132
+ if session_save:
133
+ await save_conversation_history(
134
+ query=query,
135
+ context_summary=context_summary,
136
+ answer=completion,
137
+ session_id=session_id,
138
+ )
139
+
99
140
  return completion
@@ -44,13 +44,21 @@ class CypherSearchRetriever(BaseRetriever):
44
44
  """
45
45
  try:
46
46
  graph_engine = await get_graph_engine()
47
+ is_empty = await graph_engine.is_empty()
48
+
49
+ if is_empty:
50
+ logger.warning("Search attempt on an empty knowledge graph")
51
+ return []
52
+
47
53
  result = await graph_engine.query(query)
48
54
  except Exception as e:
49
55
  logger.error("Failed to execture cypher search retrieval: %s", str(e))
50
56
  raise CypherSearchError() from e
51
57
  return result
52
58
 
53
- async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
59
+ async def get_completion(
60
+ self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
61
+ ) -> Any:
54
62
  """
55
63
  Returns the graph connections context.
56
64
 
@@ -62,6 +70,8 @@ class CypherSearchRetriever(BaseRetriever):
62
70
  - query (str): The query to retrieve context.
63
71
  - context (Optional[Any]): Optional context to use, otherwise fetched using the
64
72
  query. (default None)
73
+ - session_id (Optional[str]): Optional session identifier for caching. If None,
74
+ defaults to 'default_session'. (default None)
65
75
 
66
76
  Returns:
67
77
  --------
@@ -1,8 +1,15 @@
1
+ import asyncio
1
2
  from typing import Optional, List, Type
2
3
  from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
3
4
  from cognee.shared.logging_utils import get_logger
4
5
  from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
5
- from cognee.modules.retrieval.utils.completion import generate_completion
6
+ from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
7
+ from cognee.modules.retrieval.utils.session_cache import (
8
+ save_conversation_history,
9
+ get_conversation_history,
10
+ )
11
+ from cognee.context_global_variables import session_user
12
+ from cognee.infrastructure.databases.cache.config import CacheConfig
6
13
 
7
14
  logger = get_logger()
8
15
 
@@ -47,6 +54,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
47
54
  self,
48
55
  query: str,
49
56
  context: Optional[List[Edge]] = None,
57
+ session_id: Optional[str] = None,
50
58
  context_extension_rounds=4,
51
59
  ) -> List[str]:
52
60
  """
@@ -64,6 +72,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
64
72
  - query (str): The input query for which the completion is generated.
65
73
  - context (Optional[Any]): The existing context to use for enhancing the query; if
66
74
  None, it will be initialized from triplets generated for the query. (default None)
75
+ - session_id (Optional[str]): Optional session identifier for caching. If None,
76
+ defaults to 'default_session'. (default None)
67
77
  - context_extension_rounds: The maximum number of rounds to extend the context with
68
78
  new triplets before halting. (default 4)
69
79
 
@@ -115,17 +125,46 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
115
125
 
116
126
  round_idx += 1
117
127
 
118
- completion = await generate_completion(
119
- query=query,
120
- context=context_text,
121
- user_prompt_path=self.user_prompt_path,
122
- system_prompt_path=self.system_prompt_path,
123
- system_prompt=self.system_prompt,
124
- )
128
+ # Check if we need to generate context summary for caching
129
+ cache_config = CacheConfig()
130
+ user = session_user.get()
131
+ user_id = getattr(user, "id", None)
132
+ session_save = user_id and cache_config.caching
133
+
134
+ if session_save:
135
+ conversation_history = await get_conversation_history(session_id=session_id)
136
+
137
+ context_summary, completion = await asyncio.gather(
138
+ summarize_text(context_text),
139
+ generate_completion(
140
+ query=query,
141
+ context=context_text,
142
+ user_prompt_path=self.user_prompt_path,
143
+ system_prompt_path=self.system_prompt_path,
144
+ system_prompt=self.system_prompt,
145
+ conversation_history=conversation_history,
146
+ ),
147
+ )
148
+ else:
149
+ completion = await generate_completion(
150
+ query=query,
151
+ context=context_text,
152
+ user_prompt_path=self.user_prompt_path,
153
+ system_prompt_path=self.system_prompt_path,
154
+ system_prompt=self.system_prompt,
155
+ )
125
156
 
126
157
  if self.save_interaction and context_text and triplets and completion:
127
158
  await self.save_qa(
128
159
  question=query, answer=completion, context=context_text, triplets=triplets
129
160
  )
130
161
 
162
+ if session_save:
163
+ await save_conversation_history(
164
+ query=query,
165
+ context_summary=context_summary,
166
+ answer=completion,
167
+ session_id=session_id,
168
+ )
169
+
131
170
  return [completion]
@@ -1,11 +1,18 @@
1
+ import asyncio
1
2
  from typing import Optional, List, Type, Any
2
3
  from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
3
4
  from cognee.shared.logging_utils import get_logger
4
5
 
5
6
  from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
6
- from cognee.modules.retrieval.utils.completion import generate_completion
7
+ from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
8
+ from cognee.modules.retrieval.utils.session_cache import (
9
+ save_conversation_history,
10
+ get_conversation_history,
11
+ )
7
12
  from cognee.infrastructure.llm.LLMGateway import LLMGateway
8
13
  from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
14
+ from cognee.context_global_variables import session_user
15
+ from cognee.infrastructure.databases.cache.config import CacheConfig
9
16
 
10
17
  logger = get_logger()
11
18
 
@@ -58,6 +65,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
58
65
  self,
59
66
  query: str,
60
67
  context: Optional[List[Edge]] = None,
68
+ session_id: Optional[str] = None,
61
69
  max_iter=4,
62
70
  ) -> List[str]:
63
71
  """
@@ -74,6 +82,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
74
82
  - query (str): The user's query to be processed and answered.
75
83
  - context (Optional[Any]): Optional context that may assist in answering the query.
76
84
  If not provided, it will be fetched based on the query. (default None)
85
+ - session_id (Optional[str]): Optional session identifier for caching. If None,
86
+ defaults to 'default_session'. (default None)
77
87
  - max_iter: The maximum number of iterations to refine the answer and generate
78
88
  follow-up questions. (default 4)
79
89
 
@@ -86,6 +96,16 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
86
96
  triplets = []
87
97
  completion = ""
88
98
 
99
+ # Retrieve conversation history if session saving is enabled
100
+ cache_config = CacheConfig()
101
+ user = session_user.get()
102
+ user_id = getattr(user, "id", None)
103
+ session_save = user_id and cache_config.caching
104
+
105
+ conversation_history = ""
106
+ if session_save:
107
+ conversation_history = await get_conversation_history(session_id=session_id)
108
+
89
109
  for round_idx in range(max_iter + 1):
90
110
  if round_idx == 0:
91
111
  if context is None:
@@ -103,6 +123,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
103
123
  user_prompt_path=self.user_prompt_path,
104
124
  system_prompt_path=self.system_prompt_path,
105
125
  system_prompt=self.system_prompt,
126
+ conversation_history=conversation_history if session_save else None,
106
127
  )
107
128
  logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
108
129
  if round_idx < max_iter:
@@ -139,4 +160,14 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
139
160
  question=query, answer=completion, context=context_text, triplets=triplets
140
161
  )
141
162
 
163
+ # Save to session cache
164
+ if session_save:
165
+ context_summary = await summarize_text(context_text)
166
+ await save_conversation_history(
167
+ query=query,
168
+ context_summary=context_summary,
169
+ answer=completion,
170
+ session_id=session_id,
171
+ )
172
+
142
173
  return [completion]