cognee 0.2.1.dev7__py3-none-any.whl → 0.2.2.dev1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their public registries.
- cognee/api/client.py +44 -4
- cognee/api/health.py +332 -0
- cognee/api/v1/add/add.py +5 -2
- cognee/api/v1/add/routers/get_add_router.py +3 -0
- cognee/api/v1/cognify/code_graph_pipeline.py +3 -1
- cognee/api/v1/cognify/cognify.py +8 -0
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -1
- cognee/api/v1/config/config.py +3 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +2 -8
- cognee/api/v1/delete/delete.py +16 -12
- cognee/api/v1/responses/routers/get_responses_router.py +3 -1
- cognee/api/v1/search/search.py +10 -0
- cognee/api/v1/settings/routers/get_settings_router.py +0 -2
- cognee/base_config.py +1 -0
- cognee/eval_framework/evaluation/direct_llm_eval_adapter.py +5 -6
- cognee/infrastructure/databases/graph/config.py +2 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +58 -12
- cognee/infrastructure/databases/graph/graph_db_interface.py +15 -10
- cognee/infrastructure/databases/graph/kuzu/adapter.py +43 -16
- cognee/infrastructure/databases/graph/kuzu/kuzu_migrate.py +281 -0
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +151 -77
- cognee/infrastructure/databases/graph/neptune_driver/__init__.py +15 -0
- cognee/infrastructure/databases/graph/neptune_driver/adapter.py +1427 -0
- cognee/infrastructure/databases/graph/neptune_driver/exceptions.py +115 -0
- cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py +224 -0
- cognee/infrastructure/databases/graph/networkx/adapter.py +3 -3
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +449 -0
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +11 -3
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +8 -3
- cognee/infrastructure/databases/vector/create_vector_engine.py +31 -23
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +3 -1
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +21 -6
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +4 -3
- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +3 -1
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +22 -16
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +36 -34
- cognee/infrastructure/databases/vector/vector_db_interface.py +78 -7
- cognee/infrastructure/files/utils/get_data_file_path.py +39 -0
- cognee/infrastructure/files/utils/guess_file_type.py +2 -2
- cognee/infrastructure/files/utils/open_data_file.py +4 -23
- cognee/infrastructure/llm/LLMGateway.py +137 -0
- cognee/infrastructure/llm/__init__.py +14 -4
- cognee/infrastructure/llm/config.py +29 -1
- cognee/infrastructure/llm/prompts/answer_hotpot_question.txt +1 -1
- cognee/infrastructure/llm/prompts/answer_hotpot_using_cognee_search.txt +1 -1
- cognee/infrastructure/llm/prompts/answer_simple_question.txt +1 -1
- cognee/infrastructure/llm/prompts/answer_simple_question_restricted.txt +1 -1
- cognee/infrastructure/llm/prompts/categorize_categories.txt +1 -1
- cognee/infrastructure/llm/prompts/classify_content.txt +1 -1
- cognee/infrastructure/llm/prompts/context_for_question.txt +1 -1
- cognee/infrastructure/llm/prompts/graph_context_for_question.txt +1 -1
- cognee/infrastructure/llm/prompts/natural_language_retriever_system.txt +1 -1
- cognee/infrastructure/llm/prompts/patch_gen_instructions.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +130 -0
- cognee/infrastructure/llm/prompts/summarize_code.txt +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/__init__.py +57 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/async_client.py +533 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/config.py +94 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/globals.py +37 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/inlinedbaml.py +21 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/parser.py +131 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/runtime.py +266 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/stream_types.py +137 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/sync_client.py +550 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/tracing.py +26 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/type_builder.py +962 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/type_map.py +52 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/types.py +166 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extract_categories.baml +109 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extract_content_graph.baml +343 -0
- cognee/{modules/data → infrastructure/llm/structured_output_framework/baml/baml_src}/extraction/__init__.py +1 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/extract_summary.py +89 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py +33 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/generators.baml +18 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +3 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_categories.py +12 -0
- cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/extract_summary.py +16 -7
- cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/knowledge_graph/extract_content_graph.py +7 -6
- cognee/infrastructure/llm/{anthropic → structured_output_framework/litellm_instructor/llm/anthropic}/adapter.py +10 -4
- cognee/infrastructure/llm/{gemini → structured_output_framework/litellm_instructor/llm/gemini}/adapter.py +6 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/__init__.py +0 -0
- cognee/infrastructure/llm/{generic_llm_api → structured_output_framework/litellm_instructor/llm/generic_llm_api}/adapter.py +7 -3
- cognee/infrastructure/llm/{get_llm_client.py → structured_output_framework/litellm_instructor/llm/get_llm_client.py} +18 -6
- cognee/infrastructure/llm/{llm_interface.py → structured_output_framework/litellm_instructor/llm/llm_interface.py} +2 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/__init__.py +0 -0
- cognee/infrastructure/llm/{ollama → structured_output_framework/litellm_instructor/llm/ollama}/adapter.py +4 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/__init__.py +0 -0
- cognee/infrastructure/llm/{openai → structured_output_framework/litellm_instructor/llm/openai}/adapter.py +6 -4
- cognee/infrastructure/llm/{rate_limiter.py → structured_output_framework/litellm_instructor/llm/rate_limiter.py} +0 -5
- cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +4 -2
- cognee/infrastructure/llm/tokenizer/TikToken/adapter.py +7 -3
- cognee/infrastructure/llm/tokenizer/__init__.py +4 -0
- cognee/infrastructure/llm/utils.py +3 -1
- cognee/infrastructure/loaders/LoaderEngine.py +156 -0
- cognee/infrastructure/loaders/LoaderInterface.py +73 -0
- cognee/infrastructure/loaders/__init__.py +18 -0
- cognee/infrastructure/loaders/core/__init__.py +7 -0
- cognee/infrastructure/loaders/core/audio_loader.py +98 -0
- cognee/infrastructure/loaders/core/image_loader.py +114 -0
- cognee/infrastructure/loaders/core/text_loader.py +90 -0
- cognee/infrastructure/loaders/create_loader_engine.py +32 -0
- cognee/infrastructure/loaders/external/__init__.py +22 -0
- cognee/infrastructure/loaders/external/pypdf_loader.py +96 -0
- cognee/infrastructure/loaders/external/unstructured_loader.py +127 -0
- cognee/infrastructure/loaders/get_loader_engine.py +18 -0
- cognee/infrastructure/loaders/supported_loaders.py +18 -0
- cognee/infrastructure/loaders/use_loader.py +21 -0
- cognee/infrastructure/loaders/utils/__init__.py +0 -0
- cognee/modules/data/methods/__init__.py +1 -0
- cognee/modules/data/methods/get_authorized_dataset.py +23 -0
- cognee/modules/data/models/Data.py +13 -3
- cognee/modules/data/processing/document_types/AudioDocument.py +2 -2
- cognee/modules/data/processing/document_types/ImageDocument.py +2 -2
- cognee/modules/data/processing/document_types/PdfDocument.py +4 -11
- cognee/modules/data/processing/document_types/UnstructuredDocument.py +2 -5
- cognee/modules/engine/utils/generate_edge_id.py +5 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +45 -35
- cognee/modules/graph/methods/get_formatted_graph_data.py +8 -2
- cognee/modules/graph/utils/get_graph_from_model.py +93 -101
- cognee/modules/ingestion/data_types/TextData.py +8 -2
- cognee/modules/ingestion/save_data_to_file.py +1 -1
- cognee/modules/pipelines/exceptions/__init__.py +1 -0
- cognee/modules/pipelines/exceptions/exceptions.py +12 -0
- cognee/modules/pipelines/models/DataItemStatus.py +5 -0
- cognee/modules/pipelines/models/PipelineRunInfo.py +6 -0
- cognee/modules/pipelines/models/__init__.py +1 -0
- cognee/modules/pipelines/operations/pipeline.py +10 -2
- cognee/modules/pipelines/operations/run_tasks.py +252 -20
- cognee/modules/pipelines/operations/run_tasks_distributed.py +1 -1
- cognee/modules/retrieval/chunks_retriever.py +23 -1
- cognee/modules/retrieval/code_retriever.py +66 -9
- cognee/modules/retrieval/completion_retriever.py +11 -9
- cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +0 -2
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +0 -2
- cognee/modules/retrieval/graph_completion_cot_retriever.py +8 -9
- cognee/modules/retrieval/graph_completion_retriever.py +1 -1
- cognee/modules/retrieval/insights_retriever.py +4 -0
- cognee/modules/retrieval/natural_language_retriever.py +9 -15
- cognee/modules/retrieval/summaries_retriever.py +23 -1
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +23 -4
- cognee/modules/retrieval/utils/completion.py +6 -9
- cognee/modules/retrieval/utils/description_to_codepart_search.py +2 -3
- cognee/modules/search/methods/search.py +5 -1
- cognee/modules/search/operations/__init__.py +1 -0
- cognee/modules/search/operations/select_search_type.py +42 -0
- cognee/modules/search/types/SearchType.py +1 -0
- cognee/modules/settings/get_settings.py +0 -8
- cognee/modules/settings/save_vector_db_config.py +1 -1
- cognee/shared/data_models.py +3 -1
- cognee/shared/logging_utils.py +0 -5
- cognee/tasks/chunk_naive_llm_classifier/chunk_naive_llm_classifier.py +2 -2
- cognee/tasks/documents/extract_chunks_from_documents.py +10 -12
- cognee/tasks/entity_completion/entity_extractors/llm_entity_extractor.py +4 -6
- cognee/tasks/graph/cascade_extract/utils/extract_content_nodes_and_relationship_names.py +4 -6
- cognee/tasks/graph/cascade_extract/utils/extract_edge_triplets.py +6 -7
- cognee/tasks/graph/cascade_extract/utils/extract_nodes.py +4 -7
- cognee/tasks/graph/extract_graph_from_code.py +3 -2
- cognee/tasks/graph/extract_graph_from_data.py +4 -3
- cognee/tasks/graph/infer_data_ontology.py +5 -6
- cognee/tasks/ingestion/data_item_to_text_file.py +79 -0
- cognee/tasks/ingestion/ingest_data.py +91 -61
- cognee/tasks/ingestion/resolve_data_directories.py +3 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +3 -0
- cognee/tasks/storage/index_data_points.py +1 -1
- cognee/tasks/storage/index_graph_edges.py +4 -1
- cognee/tasks/summarization/summarize_code.py +2 -3
- cognee/tasks/summarization/summarize_text.py +3 -2
- cognee/tests/test_cognee_server_start.py +12 -7
- cognee/tests/test_deduplication.py +2 -2
- cognee/tests/test_deletion.py +58 -17
- cognee/tests/test_graph_visualization_permissions.py +161 -0
- cognee/tests/test_neptune_analytics_graph.py +309 -0
- cognee/tests/test_neptune_analytics_hybrid.py +176 -0
- cognee/tests/{test_weaviate.py → test_neptune_analytics_vector.py} +86 -11
- cognee/tests/test_pgvector.py +5 -5
- cognee/tests/test_s3.py +1 -6
- cognee/tests/unit/infrastructure/databases/test_rate_limiter.py +11 -10
- cognee/tests/unit/infrastructure/databases/vector/__init__.py +0 -0
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +1 -1
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +5 -5
- cognee/tests/unit/infrastructure/test_rate_limiting_realistic.py +6 -4
- cognee/tests/unit/infrastructure/test_rate_limiting_retry.py +1 -1
- cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_test.py +61 -3
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +84 -9
- cognee/tests/unit/modules/search/search_methods_test.py +55 -0
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/METADATA +13 -9
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/RECORD +203 -164
- cognee/infrastructure/databases/vector/pinecone/adapter.py +0 -8
- cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py +0 -514
- cognee/infrastructure/databases/vector/qdrant/__init__.py +0 -2
- cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +0 -527
- cognee/infrastructure/databases/vector/weaviate_db/__init__.py +0 -1
- cognee/modules/data/extraction/extract_categories.py +0 -14
- cognee/tests/test_qdrant.py +0 -99
- distributed/Dockerfile +0 -34
- distributed/app.py +0 -4
- distributed/entrypoint.py +0 -71
- distributed/entrypoint.sh +0 -5
- distributed/modal_image.py +0 -11
- distributed/queues.py +0 -5
- distributed/tasks/queued_add_data_points.py +0 -13
- distributed/tasks/queued_add_edges.py +0 -13
- distributed/tasks/queued_add_nodes.py +0 -13
- distributed/test.py +0 -28
- distributed/utils.py +0 -19
- distributed/workers/data_point_saving_worker.py +0 -93
- distributed/workers/graph_saving_worker.py +0 -104
- /cognee/infrastructure/databases/{graph/memgraph → hybrid/neptune_analytics}/__init__.py +0 -0
- /cognee/infrastructure/{llm → databases/vector/embeddings}/embedding_rate_limiter.py +0 -0
- /cognee/infrastructure/{databases/vector/pinecone → llm/structured_output_framework}/__init__.py +0 -0
- /cognee/infrastructure/llm/{anthropic → structured_output_framework/baml/baml_src}/__init__.py +0 -0
- /cognee/infrastructure/llm/{gemini/__init__.py → structured_output_framework/baml/baml_src/extraction/extract_categories.py} +0 -0
- /cognee/infrastructure/llm/{generic_llm_api → structured_output_framework/baml/baml_src/extraction/knowledge_graph}/__init__.py +0 -0
- /cognee/infrastructure/llm/{ollama → structured_output_framework/litellm_instructor}/__init__.py +0 -0
- /cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/knowledge_graph/__init__.py +0 -0
- /cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/texts.json +0 -0
- /cognee/infrastructure/llm/{openai → structured_output_framework/litellm_instructor/llm}/__init__.py +0 -0
- {distributed → cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic}/__init__.py +0 -0
- {distributed/tasks → cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini}/__init__.py +0 -0
- /cognee/modules/data/{extraction/knowledge_graph → methods}/add_model_class_to_graph.py +0 -0
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/WHEEL +0 -0
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/licenses/NOTICE.md +0 -0

cognee/modules/pipelines/operations/run_tasks.py

```diff
@@ -1,21 +1,31 @@
 import os
+
+import asyncio
 from uuid import UUID
-from typing import Any
+from typing import Any, List
 from functools import wraps
+from sqlalchemy import select

+import cognee.modules.ingestion as ingestion
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.relational import get_relational_engine
 from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
 from cognee.modules.users.models import User
+from cognee.modules.data.models import Data
+from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.users.methods import get_default_user
 from cognee.modules.pipelines.utils import generate_pipeline_id
+from cognee.modules.pipelines.exceptions import PipelineRunFailedError
+from cognee.tasks.ingestion import save_data_item_to_storage, resolve_data_directories
 from cognee.modules.pipelines.models.PipelineRunInfo import (
     PipelineRunCompleted,
     PipelineRunErrored,
     PipelineRunStarted,
     PipelineRunYield,
+    PipelineRunAlreadyCompleted,
 )
+from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus

 from cognee.modules.pipelines.operations import (
     log_pipeline_run_start,
@@ -50,15 +60,186 @@ def override_run_tasks(new_gen):

 @override_run_tasks(run_tasks_distributed)
 async def run_tasks(
-    tasks:
+    tasks: List[Task],
     dataset_id: UUID,
-    data: Any = None,
+    data: List[Any] = None,
     user: User = None,
     pipeline_name: str = "unknown_pipeline",
     context: dict = None,
+    incremental_loading: bool = False,
 ):
+    async def _run_tasks_data_item_incremental(
+        data_item,
+        dataset,
+        tasks,
+        pipeline_name,
+        pipeline_id,
+        pipeline_run_id,
+        context,
+        user,
+    ):
+        db_engine = get_relational_engine()
+        # If incremental_loading of data is set to True don't process documents already processed by pipeline
+        # If data is being added to Cognee for the first time calculate the id of the data
+        if not isinstance(data_item, Data):
+            file_path = await save_data_item_to_storage(data_item)
+            # Ingest data and add metadata
+            async with open_data_file(file_path) as file:
+                classified_data = ingestion.classify(file)
+            # data_id is the hash of file contents + owner id to avoid duplicate data
+            data_id = ingestion.identify(classified_data, user)
+        else:
+            # If data was already processed by Cognee get data id
+            data_id = data_item.id
+
+        # Check pipeline status, if Data already processed for pipeline before skip current processing
+        async with db_engine.get_async_session() as session:
+            data_point = (
+                await session.execute(select(Data).filter(Data.id == data_id))
+            ).scalar_one_or_none()
+            if data_point:
+                if (
+                    data_point.pipeline_status.get(pipeline_name, {}).get(str(dataset.id))
+                    == DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
+                ):
+                    yield {
+                        "run_info": PipelineRunAlreadyCompleted(
+                            pipeline_run_id=pipeline_run_id,
+                            dataset_id=dataset.id,
+                            dataset_name=dataset.name,
+                        ),
+                        "data_id": data_id,
+                    }
+                    return
+
+        try:
+            # Process data based on data_item and list of tasks
+            async for result in run_tasks_with_telemetry(
+                tasks=tasks,
+                data=[data_item],
+                user=user,
+                pipeline_name=pipeline_id,
+                context=context,
+            ):
+                yield PipelineRunYield(
+                    pipeline_run_id=pipeline_run_id,
+                    dataset_id=dataset.id,
+                    dataset_name=dataset.name,
+                    payload=result,
+                )
+
+            # Update pipeline status for Data element
+            async with db_engine.get_async_session() as session:
+                data_point = (
+                    await session.execute(select(Data).filter(Data.id == data_id))
+                ).scalar_one_or_none()
+                data_point.pipeline_status[pipeline_name] = {
+                    str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
+                }
+                await session.merge(data_point)
+                await session.commit()
+
+            yield {
+                "run_info": PipelineRunCompleted(
+                    pipeline_run_id=pipeline_run_id,
+                    dataset_id=dataset.id,
+                    dataset_name=dataset.name,
+                ),
+                "data_id": data_id,
+            }
+
+        except Exception as error:
+            # Temporarily swallow error and try to process rest of documents first, then re-raise error at end of data ingestion pipeline
+            logger.error(
+                f"Exception caught while processing data: {error}.\n Data processing failed for data item: {data_item}."
+            )
+            yield {
+                "run_info": PipelineRunErrored(
+                    pipeline_run_id=pipeline_run_id,
+                    payload=repr(error),
+                    dataset_id=dataset.id,
+                    dataset_name=dataset.name,
+                ),
+                "data_id": data_id,
+            }
+
+            if os.getenv("RAISE_INCREMENTAL_LOADING_ERRORS", "true").lower() == "true":
+                raise error
+
+    async def _run_tasks_data_item_regular(
+        data_item,
+        dataset,
+        tasks,
+        pipeline_id,
+        pipeline_run_id,
+        context,
+        user,
+    ):
+        # Process data based on data_item and list of tasks
+        async for result in run_tasks_with_telemetry(
+            tasks=tasks,
+            data=[data_item],
+            user=user,
+            pipeline_name=pipeline_id,
+            context=context,
+        ):
+            yield PipelineRunYield(
+                pipeline_run_id=pipeline_run_id,
+                dataset_id=dataset.id,
+                dataset_name=dataset.name,
+                payload=result,
+            )
+
+        yield {
+            "run_info": PipelineRunCompleted(
+                pipeline_run_id=pipeline_run_id,
+                dataset_id=dataset.id,
+                dataset_name=dataset.name,
+            )
+        }
+
+    async def _run_tasks_data_item(
+        data_item,
+        dataset,
+        tasks,
+        pipeline_name,
+        pipeline_id,
+        pipeline_run_id,
+        context,
+        user,
+        incremental_loading,
+    ):
+        # Go through async generator and return data item processing result. Result can be PipelineRunAlreadyCompleted when data item is skipped,
+        # PipelineRunCompleted when processing was successful and PipelineRunErrored if there were issues
+        result = None
+        if incremental_loading:
+            async for result in _run_tasks_data_item_incremental(
+                data_item=data_item,
+                dataset=dataset,
+                tasks=tasks,
+                pipeline_name=pipeline_name,
+                pipeline_id=pipeline_id,
+                pipeline_run_id=pipeline_run_id,
+                context=context,
+                user=user,
+            ):
+                pass
+        else:
+            async for result in _run_tasks_data_item_regular(
+                data_item=data_item,
+                dataset=dataset,
+                tasks=tasks,
+                pipeline_id=pipeline_id,
+                pipeline_run_id=pipeline_run_id,
+                context=context,
+                user=user,
+            ):
+                pass
+
+        return result
+
     if not user:
-        user = get_default_user()
+        user = await get_default_user()

     # Get Dataset object
     db_engine = get_relational_engine()
@@ -68,9 +249,7 @@ async def run_tasks(
         dataset = await session.get(Dataset, dataset_id)

     pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
-
     pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
-
     pipeline_run_id = pipeline_run.pipeline_run_id

     yield PipelineRunStarted(
@@ -81,18 +260,65 @@ async def run_tasks(
     )

     try:
-        … (12 removed lines; their content was not captured in this diff view)
+        if not isinstance(data, list):
+            data = [data]
+
+        if incremental_loading:
+            data = await resolve_data_directories(data)
+
+        # TODO: Return to using async.gather for data items after Cognee release
+        # # Create async tasks per data item that will run the pipeline for the data item
+        # data_item_tasks = [
+        #     asyncio.create_task(
+        #         _run_tasks_data_item(
+        #             data_item,
+        #             dataset,
+        #             tasks,
+        #             pipeline_name,
+        #             pipeline_id,
+        #             pipeline_run_id,
+        #             context,
+        #             user,
+        #             incremental_loading,
+        #         )
+        #     )
+        #     for data_item in data
+        # ]
+        # results = await asyncio.gather(*data_item_tasks)
+        # # Remove skipped data items from results
+        # results = [result for result in results if result]
+
+        ### TEMP sync data item handling
+        results = []
+        # Run the pipeline for each data_item sequentially, one after the other
+        for data_item in data:
+            result = await _run_tasks_data_item(
+                data_item,
+                dataset,
+                tasks,
+                pipeline_name,
+                pipeline_id,
+                pipeline_run_id,
+                context,
+                user,
+                incremental_loading,
+            )
+
+            # Skip items that returned a false-y value
+            if result:
+                results.append(result)
+        ### END
+
+        # Remove skipped data items from results
+        results = [result for result in results if result]
+
+        # If any data item could not be processed propagate error
+        errored_results = [
+            result for result in results if isinstance(result["run_info"], PipelineRunErrored)
+        ]
+        if errored_results:
+            raise PipelineRunFailedError(
+                message="Pipeline run failed. Data item could not be processed."
             )

         await log_pipeline_run_complete(
@@ -103,6 +329,7 @@ async def run_tasks(
             pipeline_run_id=pipeline_run_id,
             dataset_id=dataset.id,
             dataset_name=dataset.name,
+            data_ingestion_info=results,
         )

         graph_engine = await get_graph_engine()
@@ -120,9 +347,14 @@ async def run_tasks(

         yield PipelineRunErrored(
             pipeline_run_id=pipeline_run_id,
-            payload=error,
+            payload=repr(error),
             dataset_id=dataset.id,
             dataset_name=dataset.name,
+            data_ingestion_info=locals().get(
+                "results"
+            ),  # Returns results if they exist or returns None
         )

-        raise error
+        # In case of error during incremental loading of data just let the user know the pipeline Errored, don't raise error
+        if not isinstance(error, PipelineRunFailedError):
+            raise error
```
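
The run_tasks change above is the core of the new incremental-loading behaviour: each data item is hashed, checked against Data.pipeline_status, and skipped with PipelineRunAlreadyCompleted if it was already processed for the dataset. Below is a minimal sketch of how a caller might exercise the new flag; the Task import path and the placeholder task, dataset id, and file path are assumptions for illustration, not part of this diff.

```python
# Hypothetical usage sketch for the new run_tasks signature (not taken from this release's code).
# The Task import path and the placeholder task, dataset id, and file path are assumptions.
from uuid import UUID

from cognee.modules.pipelines.operations.run_tasks import run_tasks
from cognee.modules.pipelines.tasks.task import Task  # assumed import path


async def ingest(dataset_id: UUID, user=None):
    async def passthrough(data):
        return data

    # With incremental_loading=True, items whose pipeline_status is already
    # DATA_ITEM_PROCESSING_COMPLETED for this dataset are skipped.
    async for run_info in run_tasks(
        tasks=[Task(passthrough)],           # placeholder task list
        dataset_id=dataset_id,
        data=["./documents/report.txt"],     # placeholder data item
        user=user,
        pipeline_name="custom_pipeline",
        incremental_loading=True,
    ):
        print(run_info)
```

Per-item exceptions are logged and, with the default RAISE_INCREMENTAL_LOADING_ERRORS=true, re-raised after the PipelineRunErrored info is yielded; any errored item also causes a PipelineRunFailedError at the end of the run.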

cognee/modules/pipelines/operations/run_tasks_distributed.py

```diff
@@ -44,7 +44,7 @@ if modal:

     async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, context):
         if not user:
-            user = get_default_user()
+            user = await get_default_user()

         db_engine = get_relational_engine()
         async with db_engine.get_async_session() as session:
```

cognee/modules/retrieval/chunks_retriever.py

```diff
@@ -1,10 +1,13 @@
 from typing import Any, Optional

+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError

+logger = get_logger("ChunksRetriever")
+

 class ChunksRetriever(BaseRetriever):
     """
@@ -41,14 +44,22 @@ class ChunksRetriever(BaseRetriever):

             - Any: A list of document chunk payloads retrieved from the search.
         """
+        logger.info(
+            f"Starting chunk retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         vector_engine = get_vector_engine()

         try:
             found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
+            logger.info(f"Found {len(found_chunks)} chunks from vector search")
         except CollectionNotFoundError as error:
+            logger.error("DocumentChunk_text collection not found in vector database")
             raise NoDataError("No data found in the system, please add data first.") from error

-        … (removed line; content not captured in this diff view)
+        chunk_payloads = [result.payload for result in found_chunks]
+        logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
+        return chunk_payloads

     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """
@@ -70,6 +81,17 @@ class ChunksRetriever(BaseRetriever):
             - Any: The context used for the completion or the retrieved context if none was
             provided.
         """
+        logger.info(
+            f"Starting completion generation for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         if context is None:
+            logger.debug("No context provided, retrieving context from vector database")
             context = await self.get_context(query)
+        else:
+            logger.debug("Using provided context")
+
+        logger.info(
+            f"Returning context with {len(context) if isinstance(context, list) else 1} item(s)"
+        )
         return context
```
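
The retriever itself is unchanged apart from the logging and the explicit chunk_payloads return; for orientation, a small sketch of driving it directly follows. The top_k constructor argument is an assumption about the existing class, not something this diff introduces.

```python
# Hypothetical sketch (not from the diff): calling ChunksRetriever directly.
# The top_k constructor argument is assumed; the diff only adds logging and
# makes the payload return explicit.
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever


async def show_chunks(query: str) -> None:
    retriever = ChunksRetriever(top_k=5)  # assumed constructor parameter
    # get_context() searches the "DocumentChunk_text" collection and returns
    # the matching chunk payloads, logging how many were found.
    for payload in await retriever.get_context(query):
        print(payload.get("text", ""))  # "text" is the chunk field used elsewhere in this diff
```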

cognee/modules/retrieval/code_retriever.py

```diff
@@ -3,11 +3,13 @@ import asyncio
 import aiofiles
 from pydantic import BaseModel

+from cognee.shared.logging_utils import get_logger
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.vector import get_vector_engine
-from cognee.infrastructure.llm.
-…
+from cognee.infrastructure.llm.LLMGateway import LLMGateway
+
+logger = get_logger("CodeRetriever")


 class CodeRetriever(BaseRetriever):
@@ -35,26 +37,42 @@ class CodeRetriever(BaseRetriever):

     async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
         """Process the query using LLM to extract file names and source code parts."""
-        …
-        …
+        logger.debug(
+            f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
+        system_prompt = LLMGateway.read_query_prompt("codegraph_retriever_system.txt")
+
         try:
-            …
+            result = await LLMGateway.acreate_structured_output(
                 text_input=query,
                 system_prompt=system_prompt,
                 response_model=self.CodeQueryInfo,
             )
+            logger.info(
+                f"LLM extracted {len(result.filenames)} filenames and {len(result.sourcecode)} chars of source code"
+            )
+            return result
         except Exception as e:
+            logger.error(f"Failed to retrieve structured output from LLM: {str(e)}")
             raise RuntimeError("Failed to retrieve structured output from LLM") from e

     async def get_context(self, query: str) -> Any:
         """Find relevant code files based on the query."""
+        logger.info(
+            f"Starting code retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         if not query or not isinstance(query, str):
+            logger.error("Invalid query: must be a non-empty string")
             raise ValueError("The query must be a non-empty string.")

         try:
             vector_engine = get_vector_engine()
             graph_engine = await get_graph_engine()
+            logger.debug("Successfully initialized vector and graph engines")
         except Exception as e:
+            logger.error(f"Database initialization error: {str(e)}")
             raise RuntimeError("Database initialization error in code_graph_retriever, ") from e

         files_and_codeparts = await self._process_query(query)
@@ -63,52 +81,80 @@ class CodeRetriever(BaseRetriever):
         similar_codepieces = []

         if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
+            logger.info("No specific files/code extracted from query, performing general search")
+
             for collection in self.file_name_collections:
+                logger.debug(f"Searching {collection} collection with general query")
                 search_results_file = await vector_engine.search(
                     collection, query, limit=self.top_k
                 )
+                logger.debug(f"Found {len(search_results_file)} results in {collection}")
                 for res in search_results_file:
                     similar_filenames.append(
                         {"id": res.id, "score": res.score, "payload": res.payload}
                     )

             for collection in self.classes_and_functions_collections:
+                logger.debug(f"Searching {collection} collection with general query")
                 search_results_code = await vector_engine.search(
                     collection, query, limit=self.top_k
                 )
+                logger.debug(f"Found {len(search_results_code)} results in {collection}")
                 for res in search_results_code:
                     similar_codepieces.append(
                         {"id": res.id, "score": res.score, "payload": res.payload}
                     )
         else:
+            logger.info(
+                f"Using extracted filenames ({len(files_and_codeparts.filenames)}) and source code for targeted search"
+            )
+
             for collection in self.file_name_collections:
                 for file_from_query in files_and_codeparts.filenames:
+                    logger.debug(f"Searching {collection} for specific file: {file_from_query}")
                     search_results_file = await vector_engine.search(
                         collection, file_from_query, limit=self.top_k
                     )
+                    logger.debug(
+                        f"Found {len(search_results_file)} results for file {file_from_query}"
+                    )
                     for res in search_results_file:
                         similar_filenames.append(
                             {"id": res.id, "score": res.score, "payload": res.payload}
                         )

             for collection in self.classes_and_functions_collections:
+                logger.debug(f"Searching {collection} with extracted source code")
                 search_results_code = await vector_engine.search(
                     collection, files_and_codeparts.sourcecode, limit=self.top_k
                 )
+                logger.debug(f"Found {len(search_results_code)} results for source code search")
                 for res in search_results_code:
                     similar_codepieces.append(
                         {"id": res.id, "score": res.score, "payload": res.payload}
                     )

+        total_items = len(similar_filenames) + len(similar_codepieces)
+        logger.info(
+            f"Total search results: {total_items} items ({len(similar_filenames)} filenames, {len(similar_codepieces)} code pieces)"
+        )
+
+        if total_items == 0:
+            logger.warning("No search results found, returning empty list")
+            return []
+
+        logger.debug("Getting graph connections for all search results")
         relevant_triplets = await asyncio.gather(
             *[
                 graph_engine.get_connections(similar_piece["id"])
                 for similar_piece in similar_filenames + similar_codepieces
             ]
         )
+        logger.info(f"Retrieved graph connections for {len(relevant_triplets)} items")

         paths = set()
-        for sublist in relevant_triplets:
+        for i, sublist in enumerate(relevant_triplets):
+            logger.debug(f"Processing connections for item {i}: {len(sublist)} connections")
             for tpl in sublist:
                 if isinstance(tpl, tuple) and len(tpl) >= 3:
                     if "file_path" in tpl[0]:
@@ -116,23 +162,31 @@ class CodeRetriever(BaseRetriever):
                     if "file_path" in tpl[2]:
                         paths.add(tpl[2]["file_path"])

+        logger.info(f"Found {len(paths)} unique file paths to read")
+
         retrieved_files = {}
         read_tasks = []
         for file_path in paths:

             async def read_file(fp):
                 try:
+                    logger.debug(f"Reading file: {fp}")
                     async with aiofiles.open(fp, "r", encoding="utf-8") as f:
-                        …
+                        content = await f.read()
+                        retrieved_files[fp] = content
+                        logger.debug(f"Successfully read {len(content)} characters from {fp}")
                 except Exception as e:
-                    …
+                    logger.error(f"Error reading {fp}: {e}")
                     retrieved_files[fp] = ""

             read_tasks.append(read_file(file_path))

         await asyncio.gather(*read_tasks)
+        logger.info(
+            f"Successfully read {len([f for f in retrieved_files.values() if f])} files (out of {len(paths)} total)"
+        )

-        …
+        result = [
             {
                 "name": file_path,
                 "description": file_path,
@@ -141,6 +195,9 @@ class CodeRetriever(BaseRetriever):
             for file_path in paths
         ]

+        logger.info(f"Returning {len(result)} code file contexts")
+        return result
+
     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """Returns the code files context."""
         if context is None:
```
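
CodeRetriever now goes through the new LLMGateway facade (cognee/infrastructure/llm/LLMGateway.py, added in this release) instead of importing get_llm_client and the prompt helpers separately. A minimal sketch of that call pattern follows; the two LLMGateway methods are the ones the diff above calls, while the QueryInfo model is made up for illustration.

```python
# Hypothetical sketch of the LLMGateway call pattern used by the updated retrievers.
# QueryInfo is a made-up response model; read_query_prompt and
# acreate_structured_output are the gateway methods the diff above calls.
from typing import List

from pydantic import BaseModel

from cognee.infrastructure.llm.LLMGateway import LLMGateway


class QueryInfo(BaseModel):
    filenames: List[str]
    sourcecode: str


async def extract_query_info(query: str) -> QueryInfo:
    # Load the system prompt shipped with cognee, then request structured output.
    system_prompt = LLMGateway.read_query_prompt("codegraph_retriever_system.txt")
    return await LLMGateway.acreate_structured_output(
        text_input=query,
        system_prompt=system_prompt,
        response_model=QueryInfo,
    )
```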

cognee/modules/retrieval/completion_retriever.py

```diff
@@ -1,11 +1,14 @@
 from typing import Any, Optional

+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.utils.completion import generate_completion
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError

+logger = get_logger("CompletionRetriever")
+

 class CompletionRetriever(BaseRetriever):
     """
@@ -56,8 +59,10 @@ class CompletionRetriever(BaseRetriever):

             # Combine all chunks text returned from vector search (number of chunks is determined by top_k
             chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
-            …
+            combined_context = "\n".join(chunks_payload)
+            return combined_context
         except CollectionNotFoundError as error:
+            logger.error("DocumentChunk_text collection not found")
             raise NoDataError("No data found in the system, please add data first.") from error

     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
@@ -70,22 +75,19 @@ class CompletionRetriever(BaseRetriever):
         Parameters:
         -----------

-            - query (str): The
-            - context (Optional[Any]): Optional context to use for generating the
-              …
+            - query (str): The query string to be used for generating a completion.
+            - context (Optional[Any]): Optional pre-fetched context to use for generating the
+              completion; if None, it retrieves the context for the query. (default None)

         Returns:
         --------

-            - Any:
+            - Any: The generated completion based on the provided query and context.
         """
         if context is None:
             context = await self.get_context(query)

         completion = await generate_completion(
-            query
-            context=context,
-            user_prompt_path=self.user_prompt_path,
-            system_prompt_path=self.system_prompt_path,
+            query, context, self.user_prompt_path, self.system_prompt_path
         )
         return [completion]
```

cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py

```diff
@@ -4,8 +4,6 @@ import asyncio
 from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
-from cognee.infrastructure.llm.get_llm_client import get_llm_client
-from cognee.infrastructure.llm.prompts import read_query_prompt
 from cognee.modules.retrieval.utils.brute_force_triplet_search import (
     brute_force_triplet_search,
     format_triplets,
```

cognee/modules/retrieval/graph_completion_cot_retriever.py

```diff
@@ -1,9 +1,7 @@
 from typing import Any, Optional, List, Type
 from cognee.shared.logging_utils import get_logger
-from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
 from cognee.modules.retrieval.utils.completion import generate_completion
-from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt

 logger = get_logger()

```