cognee 0.2.3.dev1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/__main__.py +4 -0
- cognee/api/client.py +28 -3
- cognee/api/health.py +10 -13
- cognee/api/v1/add/add.py +20 -6
- cognee/api/v1/add/routers/get_add_router.py +12 -37
- cognee/api/v1/cloud/routers/__init__.py +1 -0
- cognee/api/v1/cloud/routers/get_checks_router.py +23 -0
- cognee/api/v1/cognify/code_graph_pipeline.py +14 -3
- cognee/api/v1/cognify/cognify.py +67 -105
- cognee/api/v1/cognify/routers/get_cognify_router.py +11 -3
- cognee/api/v1/datasets/routers/get_datasets_router.py +16 -5
- cognee/api/v1/memify/routers/__init__.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +100 -0
- cognee/api/v1/notebooks/routers/__init__.py +1 -0
- cognee/api/v1/notebooks/routers/get_notebooks_router.py +96 -0
- cognee/api/v1/responses/default_tools.py +4 -0
- cognee/api/v1/responses/dispatch_function.py +6 -1
- cognee/api/v1/responses/models.py +1 -1
- cognee/api/v1/search/routers/get_search_router.py +20 -1
- cognee/api/v1/search/search.py +17 -4
- cognee/api/v1/sync/__init__.py +17 -0
- cognee/api/v1/sync/routers/__init__.py +3 -0
- cognee/api/v1/sync/routers/get_sync_router.py +241 -0
- cognee/api/v1/sync/sync.py +877 -0
- cognee/api/v1/ui/__init__.py +1 -0
- cognee/api/v1/ui/ui.py +529 -0
- cognee/api/v1/users/routers/get_auth_router.py +13 -1
- cognee/base_config.py +10 -1
- cognee/cli/__init__.py +10 -0
- cognee/cli/_cognee.py +273 -0
- cognee/cli/commands/__init__.py +1 -0
- cognee/cli/commands/add_command.py +80 -0
- cognee/cli/commands/cognify_command.py +128 -0
- cognee/cli/commands/config_command.py +225 -0
- cognee/cli/commands/delete_command.py +80 -0
- cognee/cli/commands/search_command.py +149 -0
- cognee/cli/config.py +33 -0
- cognee/cli/debug.py +21 -0
- cognee/cli/echo.py +45 -0
- cognee/cli/exceptions.py +23 -0
- cognee/cli/minimal_cli.py +97 -0
- cognee/cli/reference.py +26 -0
- cognee/cli/suppress_logging.py +12 -0
- cognee/eval_framework/corpus_builder/corpus_builder_executor.py +2 -2
- cognee/eval_framework/eval_config.py +1 -1
- cognee/infrastructure/databases/graph/config.py +10 -4
- cognee/infrastructure/databases/graph/get_graph_engine.py +4 -9
- cognee/infrastructure/databases/graph/kuzu/adapter.py +199 -2
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +138 -0
- cognee/infrastructure/databases/relational/__init__.py +2 -0
- cognee/infrastructure/databases/relational/get_async_session.py +15 -0
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +6 -1
- cognee/infrastructure/databases/relational/with_async_session.py +25 -0
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +1 -1
- cognee/infrastructure/databases/vector/config.py +13 -6
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +6 -4
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +16 -7
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +5 -5
- cognee/infrastructure/databases/vector/embeddings/config.py +2 -2
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +2 -6
- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +10 -7
- cognee/infrastructure/files/storage/LocalFileStorage.py +9 -0
- cognee/infrastructure/files/storage/S3FileStorage.py +5 -0
- cognee/infrastructure/files/storage/StorageManager.py +7 -1
- cognee/infrastructure/files/storage/storage.py +16 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +14 -9
- cognee/infrastructure/files/utils/get_file_metadata.py +2 -1
- cognee/infrastructure/llm/LLMGateway.py +32 -5
- cognee/infrastructure/llm/config.py +6 -4
- cognee/infrastructure/llm/prompts/extract_query_time.txt +15 -0
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +25 -0
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +30 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py +16 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +2 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_event_entities.py +44 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/__init__.py +1 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_content_graph.py +19 -15
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_event_graph.py +46 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +3 -3
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +3 -3
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +14 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +6 -4
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +28 -4
- cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +2 -2
- cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py +3 -3
- cognee/infrastructure/llm/tokenizer/Mistral/adapter.py +3 -3
- cognee/infrastructure/llm/tokenizer/TikToken/adapter.py +6 -6
- cognee/infrastructure/llm/utils.py +7 -7
- cognee/infrastructure/utils/run_sync.py +8 -1
- cognee/modules/chunking/models/DocumentChunk.py +4 -3
- cognee/modules/cloud/exceptions/CloudApiKeyMissingError.py +15 -0
- cognee/modules/cloud/exceptions/CloudConnectionError.py +15 -0
- cognee/modules/cloud/exceptions/__init__.py +2 -0
- cognee/modules/cloud/operations/__init__.py +1 -0
- cognee/modules/cloud/operations/check_api_key.py +25 -0
- cognee/modules/data/deletion/prune_system.py +1 -1
- cognee/modules/data/methods/__init__.py +2 -0
- cognee/modules/data/methods/check_dataset_name.py +1 -1
- cognee/modules/data/methods/create_authorized_dataset.py +19 -0
- cognee/modules/data/methods/get_authorized_dataset.py +11 -5
- cognee/modules/data/methods/get_authorized_dataset_by_name.py +16 -0
- cognee/modules/data/methods/get_dataset_data.py +1 -1
- cognee/modules/data/methods/load_or_create_datasets.py +2 -20
- cognee/modules/engine/models/Event.py +16 -0
- cognee/modules/engine/models/Interval.py +8 -0
- cognee/modules/engine/models/Timestamp.py +13 -0
- cognee/modules/engine/models/__init__.py +3 -0
- cognee/modules/engine/utils/__init__.py +2 -0
- cognee/modules/engine/utils/generate_event_datapoint.py +46 -0
- cognee/modules/engine/utils/generate_timestamp_datapoint.py +51 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +2 -2
- cognee/modules/graph/methods/get_formatted_graph_data.py +3 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/resolve_edges_to_text.py +71 -0
- cognee/modules/memify/__init__.py +1 -0
- cognee/modules/memify/memify.py +118 -0
- cognee/modules/notebooks/methods/__init__.py +5 -0
- cognee/modules/notebooks/methods/create_notebook.py +26 -0
- cognee/modules/notebooks/methods/delete_notebook.py +13 -0
- cognee/modules/notebooks/methods/get_notebook.py +21 -0
- cognee/modules/notebooks/methods/get_notebooks.py +18 -0
- cognee/modules/notebooks/methods/update_notebook.py +17 -0
- cognee/modules/notebooks/models/Notebook.py +53 -0
- cognee/modules/notebooks/models/__init__.py +1 -0
- cognee/modules/notebooks/operations/__init__.py +1 -0
- cognee/modules/notebooks/operations/run_in_local_sandbox.py +55 -0
- cognee/modules/pipelines/__init__.py +1 -1
- cognee/modules/pipelines/exceptions/tasks.py +18 -0
- cognee/modules/pipelines/layers/__init__.py +1 -0
- cognee/modules/pipelines/layers/check_pipeline_run_qualification.py +59 -0
- cognee/modules/pipelines/layers/pipeline_execution_mode.py +127 -0
- cognee/modules/pipelines/layers/reset_dataset_pipeline_run_status.py +28 -0
- cognee/modules/pipelines/layers/resolve_authorized_user_dataset.py +34 -0
- cognee/modules/pipelines/layers/resolve_authorized_user_datasets.py +55 -0
- cognee/modules/pipelines/layers/setup_and_check_environment.py +41 -0
- cognee/modules/pipelines/layers/validate_pipeline_tasks.py +20 -0
- cognee/modules/pipelines/methods/__init__.py +2 -0
- cognee/modules/pipelines/methods/get_pipeline_runs_by_dataset.py +34 -0
- cognee/modules/pipelines/methods/reset_pipeline_run_status.py +16 -0
- cognee/modules/pipelines/operations/__init__.py +0 -1
- cognee/modules/pipelines/operations/log_pipeline_run_initiated.py +1 -1
- cognee/modules/pipelines/operations/pipeline.py +24 -138
- cognee/modules/pipelines/operations/run_tasks.py +17 -41
- cognee/modules/retrieval/base_feedback.py +11 -0
- cognee/modules/retrieval/base_graph_retriever.py +18 -0
- cognee/modules/retrieval/base_retriever.py +1 -1
- cognee/modules/retrieval/code_retriever.py +8 -0
- cognee/modules/retrieval/coding_rules_retriever.py +31 -0
- cognee/modules/retrieval/completion_retriever.py +9 -3
- cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +1 -0
- cognee/modules/retrieval/cypher_search_retriever.py +1 -9
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +29 -13
- cognee/modules/retrieval/graph_completion_cot_retriever.py +30 -13
- cognee/modules/retrieval/graph_completion_retriever.py +107 -56
- cognee/modules/retrieval/graph_summary_completion_retriever.py +5 -1
- cognee/modules/retrieval/insights_retriever.py +14 -3
- cognee/modules/retrieval/natural_language_retriever.py +0 -4
- cognee/modules/retrieval/summaries_retriever.py +1 -1
- cognee/modules/retrieval/temporal_retriever.py +152 -0
- cognee/modules/retrieval/user_qa_feedback.py +83 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +7 -32
- cognee/modules/retrieval/utils/completion.py +10 -3
- cognee/modules/retrieval/utils/extract_uuid_from_node.py +18 -0
- cognee/modules/retrieval/utils/models.py +40 -0
- cognee/modules/search/methods/get_search_type_tools.py +168 -0
- cognee/modules/search/methods/no_access_control_search.py +47 -0
- cognee/modules/search/methods/search.py +239 -118
- cognee/modules/search/types/SearchResult.py +21 -0
- cognee/modules/search/types/SearchType.py +3 -0
- cognee/modules/search/types/__init__.py +1 -0
- cognee/modules/search/utils/__init__.py +2 -0
- cognee/modules/search/utils/prepare_search_result.py +41 -0
- cognee/modules/search/utils/transform_context_to_graph.py +38 -0
- cognee/modules/settings/get_settings.py +2 -2
- cognee/modules/sync/__init__.py +1 -0
- cognee/modules/sync/methods/__init__.py +23 -0
- cognee/modules/sync/methods/create_sync_operation.py +53 -0
- cognee/modules/sync/methods/get_sync_operation.py +107 -0
- cognee/modules/sync/methods/update_sync_operation.py +248 -0
- cognee/modules/sync/models/SyncOperation.py +142 -0
- cognee/modules/sync/models/__init__.py +3 -0
- cognee/modules/users/__init__.py +0 -1
- cognee/modules/users/methods/__init__.py +4 -1
- cognee/modules/users/methods/create_user.py +26 -1
- cognee/modules/users/methods/get_authenticated_user.py +36 -42
- cognee/modules/users/methods/get_default_user.py +3 -1
- cognee/modules/users/permissions/methods/get_specific_user_permission_datasets.py +2 -1
- cognee/root_dir.py +19 -0
- cognee/shared/CodeGraphEntities.py +1 -0
- cognee/shared/logging_utils.py +143 -32
- cognee/shared/utils.py +0 -1
- cognee/tasks/codingagents/coding_rule_associations.py +127 -0
- cognee/tasks/graph/extract_graph_from_data.py +6 -2
- cognee/tasks/ingestion/save_data_item_to_storage.py +23 -0
- cognee/tasks/memify/__init__.py +2 -0
- cognee/tasks/memify/extract_subgraph.py +7 -0
- cognee/tasks/memify/extract_subgraph_chunks.py +11 -0
- cognee/tasks/repo_processor/get_local_dependencies.py +2 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +144 -47
- cognee/tasks/storage/add_data_points.py +33 -3
- cognee/tasks/temporal_graph/__init__.py +1 -0
- cognee/tasks/temporal_graph/add_entities_to_event.py +85 -0
- cognee/tasks/temporal_graph/enrich_events.py +34 -0
- cognee/tasks/temporal_graph/extract_events_and_entities.py +32 -0
- cognee/tasks/temporal_graph/extract_knowledge_graph_from_events.py +41 -0
- cognee/tasks/temporal_graph/models.py +49 -0
- cognee/tests/integration/cli/__init__.py +3 -0
- cognee/tests/integration/cli/test_cli_integration.py +331 -0
- cognee/tests/integration/documents/PdfDocument_test.py +2 -2
- cognee/tests/integration/documents/TextDocument_test.py +2 -4
- cognee/tests/integration/documents/UnstructuredDocument_test.py +5 -8
- cognee/tests/{test_deletion.py → test_delete_hard.py} +0 -37
- cognee/tests/test_delete_soft.py +85 -0
- cognee/tests/test_kuzu.py +2 -2
- cognee/tests/test_neo4j.py +2 -2
- cognee/tests/test_permissions.py +3 -3
- cognee/tests/test_relational_db_migration.py +7 -5
- cognee/tests/test_search_db.py +136 -23
- cognee/tests/test_temporal_graph.py +167 -0
- cognee/tests/unit/api/__init__.py +1 -0
- cognee/tests/unit/api/test_conditional_authentication_endpoints.py +246 -0
- cognee/tests/unit/cli/__init__.py +3 -0
- cognee/tests/unit/cli/test_cli_commands.py +483 -0
- cognee/tests/unit/cli/test_cli_edge_cases.py +625 -0
- cognee/tests/unit/cli/test_cli_main.py +173 -0
- cognee/tests/unit/cli/test_cli_runner.py +62 -0
- cognee/tests/unit/cli/test_cli_utils.py +127 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +12 -15
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +10 -15
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +4 -3
- cognee/tests/unit/modules/retrieval/insights_retriever_test.py +4 -2
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +225 -0
- cognee/tests/unit/modules/users/__init__.py +1 -0
- cognee/tests/unit/modules/users/test_conditional_authentication.py +277 -0
- cognee/tests/unit/processing/utils/utils_test.py +20 -1
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/METADATA +13 -9
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/RECORD +247 -135
- cognee-0.3.0.dist-info/entry_points.txt +2 -0
- cognee/infrastructure/databases/graph/networkx/adapter.py +0 -1017
- cognee/infrastructure/pipeline/models/Operation.py +0 -60
- cognee/notebooks/github_analysis_step_by_step.ipynb +0 -37
- cognee/tests/tasks/descriptive_metrics/networkx_metrics_test.py +0 -7
- cognee/tests/unit/modules/search/search_methods_test.py +0 -223
- /cognee/{infrastructure/databases/graph/networkx → api/v1/memify}/__init__.py +0 -0
- /cognee/{infrastructure/pipeline/models → tasks/codingagents}/__init__.py +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/WHEEL +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/licenses/NOTICE.md +0 -0
```diff
@@ -2,41 +2,23 @@ import asyncio
 from uuid import UUID
 from typing import Union
 
+from cognee.modules.pipelines.layers.setup_and_check_environment import (
+    setup_and_check_environment,
+)
+
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.data.methods.get_dataset_data import get_dataset_data
 from cognee.modules.data.models import Data, Dataset
 from cognee.modules.pipelines.operations.run_tasks import run_tasks
-from cognee.modules.pipelines.
-from cognee.modules.pipelines.utils import generate_pipeline_id
-from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
-from cognee.modules.pipelines.methods import get_pipeline_run_by_dataset
-
+from cognee.modules.pipelines.layers import validate_pipeline_tasks
 from cognee.modules.pipelines.tasks.task import Task
-from cognee.modules.users.methods import get_default_user
 from cognee.modules.users.models import User
-from cognee.modules.pipelines.operations import log_pipeline_run_initiated
 from cognee.context_global_variables import set_database_global_context_variables
-from cognee.modules.
-
-    get_authorized_existing_datasets,
-    load_or_create_datasets,
-    check_dataset_name,
+from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
+    resolve_authorized_user_datasets,
 )
-
-
-    PipelineRunCompleted,
-    PipelineRunStarted,
-)
-
-from cognee.infrastructure.databases.relational import (
-    create_db_and_tables as create_relational_db_and_tables,
-)
-from cognee.infrastructure.databases.vector.pgvector import (
-    create_db_and_tables as create_pgvector_db_and_tables,
-)
-from cognee.context_global_variables import (
-    graph_db_config as context_graph_db_config,
-    vector_db_config as context_vector_db_config,
+from cognee.modules.pipelines.layers.check_pipeline_run_qualification import (
+    check_pipeline_run_qualification,
 )
 
 logger = get_logger("cognee.pipeline")
@@ -44,7 +26,7 @@ logger = get_logger("cognee.pipeline")
 update_status_lock = asyncio.Lock()
 
 
-async def cognee_pipeline(
+async def run_pipeline(
     tasks: list[Task],
     data=None,
     datasets: Union[str, list[str], list[UUID]] = None,
@@ -54,56 +36,13 @@ async def cognee_pipeline(
     graph_db_config: dict = None,
     incremental_loading: bool = False,
 ):
-
-
-    if vector_db_config:
-        context_vector_db_config.set(vector_db_config)
-    if graph_db_config:
-        context_graph_db_config.set(graph_db_config)
-
-    # Create tables for databases
-    await create_relational_db_and_tables()
-    await create_pgvector_db_and_tables()
-
-    # Initialize first_run attribute if it doesn't exist
-    if not hasattr(cognee_pipeline, "first_run"):
-        cognee_pipeline.first_run = True
-
-    if cognee_pipeline.first_run:
-        from cognee.infrastructure.llm.utils import (
-            test_llm_connection,
-            test_embedding_connection,
-        )
-
-        # Test LLM and Embedding configuration once before running Cognee
-        await test_llm_connection()
-        await test_embedding_connection()
-        cognee_pipeline.first_run = False  # Update flag after first run
+    validate_pipeline_tasks(tasks)
+    await setup_and_check_environment(vector_db_config, graph_db_config)
 
-
-    if user is None:
-        user = await get_default_user()
+    user, authorized_datasets = await resolve_authorized_user_datasets(datasets, user)
 
-
-
-        datasets = [datasets]
-
-    # Get datasets user wants write permissions for (verify user has permissions if datasets are provided as well)
-    # NOTE: If a user wants to write to a dataset he does not own it must be provided through UUID
-    existing_datasets = await get_authorized_existing_datasets(datasets, "write", user)
-
-    if not datasets:
-        # Get datasets from database if none sent.
-        datasets = existing_datasets
-    else:
-        # If dataset matches an existing Dataset (by name or id), reuse it. Otherwise, create a new Dataset.
-        datasets = await load_or_create_datasets(datasets, existing_datasets, user)
-
-    if not datasets:
-        raise DatasetNotFoundError("There are no datasets to work with.")
-
-    for dataset in datasets:
-        async for run_info in run_pipeline(
+    for dataset in authorized_datasets:
+        async for run_info in run_pipeline_per_dataset(
            dataset=dataset,
            user=user,
            tasks=tasks,
@@ -115,7 +54,7 @@ async def cognee_pipeline(
            yield run_info
 
 
-async def run_pipeline(
+async def run_pipeline_per_dataset(
     dataset: Dataset,
     user: User,
     tasks: list[Task],
@@ -124,74 +63,21 @@ async def run_pipeline(
     context: dict = None,
     incremental_loading=False,
 ):
-    check_dataset_name(dataset.name)
-
     # Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
     await set_database_global_context_variables(dataset.id, dataset.owner_id)
 
-    # Ugly hack, but no easier way to do this.
-    if pipeline_name == "add_pipeline":
-        pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
-        # Refresh the add pipeline status so data is added to a dataset.
-        # Without this the app_pipeline status will be DATASET_PROCESSING_COMPLETED and will skip the execution.
-
-        await log_pipeline_run_initiated(
-            pipeline_id=pipeline_id,
-            pipeline_name="add_pipeline",
-            dataset_id=dataset.id,
-        )
-
-        # Refresh the cognify pipeline status after we add new files.
-        # Without this the cognify_pipeline status will be DATASET_PROCESSING_COMPLETED and will skip the execution.
-        await log_pipeline_run_initiated(
-            pipeline_id=pipeline_id,
-            pipeline_name="cognify_pipeline",
-            dataset_id=dataset.id,
-        )
-
-    dataset_id = dataset.id
-
     if not data:
-        data: list[Data] = await get_dataset_data(dataset_id=
-
-    # async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests
-    if isinstance(dataset, Dataset):
-        task_status = await get_pipeline_status([dataset_id], pipeline_name)
-    else:
-        task_status = [
-            PipelineRunStatus.DATASET_PROCESSING_COMPLETED
-        ]  # TODO: this is a random assignment, find permanent solution
-
-    if str(dataset_id) in task_status:
-        if task_status[str(dataset_id)] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
-            logger.info("Dataset %s is already being processed.", dataset_id)
-            pipeline_run = await get_pipeline_run_by_dataset(dataset_id, pipeline_name)
-            yield PipelineRunStarted(
-                pipeline_run_id=pipeline_run.pipeline_run_id,
-                dataset_id=dataset.id,
-                dataset_name=dataset.name,
-                payload=data,
-            )
-            return
-        elif task_status[str(dataset_id)] == PipelineRunStatus.DATASET_PROCESSING_COMPLETED:
-            logger.info("Dataset %s is already processed.", dataset_id)
-            pipeline_run = await get_pipeline_run_by_dataset(dataset_id, pipeline_name)
-            yield PipelineRunCompleted(
-                pipeline_run_id=pipeline_run.pipeline_run_id,
-                dataset_id=dataset.id,
-                dataset_name=dataset.name,
-            )
-            return
-
-    if not isinstance(tasks, list):
-        raise ValueError("Tasks must be a list")
+        data: list[Data] = await get_dataset_data(dataset_id=dataset.id)
 
-
-
-
+    process_pipeline_status = await check_pipeline_run_qualification(dataset, data, pipeline_name)
+    if process_pipeline_status:
+        # If pipeline was already processed or is currently being processed
+        # return status information to async generator and finish execution
+        yield process_pipeline_status
+        return
 
     pipeline_run = run_tasks(
-        tasks,
+        tasks, dataset.id, data, user, pipeline_name, context, incremental_loading
    )
 
    async for pipeline_run_info in pipeline_run:
```
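Read together, these hunks replace the monolithic `cognee_pipeline` with a thin `run_pipeline` orchestrator that hands each concern to the new `layers` helpers. The sketch below is assembled from the hunks above and is not a verbatim copy of the file: the parameters elided between hunks (`user`, `pipeline_name`, `vector_db_config`), their defaults, and the pass-through keyword arguments are assumptions.

```python
# Illustrative sketch of the refactored generator (assumed defaults, not verbatim).
async def run_pipeline(
    tasks, data=None, datasets=None, user=None, pipeline_name="cognify_pipeline",
    vector_db_config=None, graph_db_config=None, incremental_loading=False,
):
    # Replaces the inline "tasks must be a list" check.
    validate_pipeline_tasks(tasks)
    # Replaces the inline table creation and one-time LLM/embedding connection tests.
    await setup_and_check_environment(vector_db_config, graph_db_config)
    # Replaces get_default_user() plus the dataset authorization/creation block.
    user, authorized_datasets = await resolve_authorized_user_datasets(datasets, user)

    for dataset in authorized_datasets:
        async for run_info in run_pipeline_per_dataset(
            dataset=dataset,
            user=user,
            tasks=tasks,
            data=data,
            pipeline_name=pipeline_name,
            incremental_loading=incremental_loading,
        ):
            yield run_info
```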
```diff
@@ -266,48 +266,24 @@ async def run_tasks
     if incremental_loading:
         data = await resolve_data_directories(data)
 
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-        # )
-        # )
-        # for data_item in data
-        # ]
-        # results = await asyncio.gather(*data_item_tasks)
-        # # Remove skipped data items from results
-        # results = [result for result in results if result]
-
-        ### TEMP sync data item handling
-        results = []
-        # Run the pipeline for each data_item sequentially, one after the other
-        for data_item in data:
-            result = await _run_tasks_data_item(
-                data_item,
-                dataset,
-                tasks,
-                pipeline_name,
-                pipeline_id,
-                pipeline_run_id,
-                context,
-                user,
-                incremental_loading,
+        # Create async tasks per data item that will run the pipeline for the data item
+        data_item_tasks = [
+            asyncio.create_task(
+                _run_tasks_data_item(
+                    data_item,
+                    dataset,
+                    tasks,
+                    pipeline_name,
+                    pipeline_id,
+                    pipeline_run_id,
+                    context,
+                    user,
+                    incremental_loading,
+                )
             )
-
-
-
-            results.append(result)
-        ### END
+            for data_item in data
+        ]
+        results = await asyncio.gather(*data_item_tasks)
 
         # Remove skipped data items from results
         results = [result for result in results if result]
```
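The incremental-loading branch previously awaited `_run_tasks_data_item` for each item in sequence; it now creates one asyncio task per data item and gathers the results. A self-contained sketch of the same fan-out/gather idiom, with a stand-in coroutine in place of `_run_tasks_data_item`:

```python
import asyncio


async def process_item(item):
    # Stand-in for _run_tasks_data_item: returns None for items that should be skipped.
    await asyncio.sleep(0)
    return item * 2 if item % 2 == 0 else None


async def main():
    data = [1, 2, 3, 4]

    # One task per data item, so the items are processed concurrently.
    item_tasks = [asyncio.create_task(process_item(item)) for item in data]
    # gather preserves input order and waits for every task to finish.
    results = await asyncio.gather(*item_tasks)

    # Remove skipped data items from results, as the pipeline does.
    results = [result for result in results if result]
    print(results)  # [4, 8]


asyncio.run(main())
```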
```diff
@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+
+class BaseFeedback(ABC):
+    """Base class for all user feedback operations."""
+
+    @abstractmethod
+    async def add_feedback(self, feedback_text: str) -> Any:
+        """Add user feedback to the system."""
+        pass
```
```diff
@@ -0,0 +1,18 @@
+from typing import List, Optional
+from abc import ABC, abstractmethod
+
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
+
+
+class BaseGraphRetriever(ABC):
+    """Base class for all graph based retrievers."""
+
+    @abstractmethod
+    async def get_context(self, query: str) -> List[Edge]:
+        """Retrieves triplets based on the query."""
+        pass
+
+    @abstractmethod
+    async def get_completion(self, query: str, context: Optional[List[Edge]] = None) -> str:
+        """Generates a response using the query and optional context (triplets)."""
+        pass
```
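The new abstract base pins down the graph retriever contract: `get_context` returns a list of `Edge` triplets and `get_completion` turns the query plus optional triplets into an answer. A hypothetical subclass, purely to show the required interface (the lookup and answer logic here are placeholders):

```python
from typing import List, Optional

from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever


class EchoGraphRetriever(BaseGraphRetriever):
    """Toy implementation used only to illustrate the abstract interface."""

    async def get_context(self, query: str) -> List[Edge]:
        # A real retriever would run a triplet search against the graph store here.
        return []

    async def get_completion(self, query: str, context: Optional[List[Edge]] = None) -> str:
        triplets = context if context is not None else await self.get_context(query)
        # A real retriever would resolve the triplets to text and call an LLM.
        return f"{len(triplets)} triplets found for: {query}"
```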
```diff
@@ -94,7 +94,15 @@ class CodeRetriever(BaseRetriever):
                    {"id": res.id, "score": res.score, "payload": res.payload}
                )
 
+        existing_collection = []
         for collection in self.classes_and_functions_collections:
+            if await vector_engine.has_collection(collection):
+                existing_collection.append(collection)
+
+        if not existing_collection:
+            raise RuntimeError("No collection found for code retriever")
+
+        for collection in existing_collection:
            logger.debug(f"Searching {collection} collection with general query")
            search_results_code = await vector_engine.search(
                collection, query, limit=self.top_k
```
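The added guard filters the configured collections through `vector_engine.has_collection` and raises when none exist, instead of searching blindly. The same guard in isolation, with a stub standing in for the vector adapter (the collection names are made up):

```python
import asyncio


class StubVectorEngine:
    """Stand-in exposing only the two calls the guard relies on."""

    def __init__(self, collections):
        self._collections = set(collections)

    async def has_collection(self, name):
        return name in self._collections

    async def search(self, collection, query, limit=5):
        return [f"{collection}: match for {query!r}"]


async def search_existing_collections(engine, configured, query):
    # Keep only the collections that actually exist in the vector store.
    existing = [name for name in configured if await engine.has_collection(name)]
    if not existing:
        raise RuntimeError("No collection found for code retriever")

    results = []
    for collection in existing:
        results += await engine.search(collection, query, limit=5)
    return results


engine = StubVectorEngine(["CodeFile_source"])  # hypothetical collection name
print(asyncio.run(search_existing_collections(engine, ["CodeFile_source", "Function_name"], "parse")))
```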
```diff
@@ -0,0 +1,31 @@
+import asyncio
+from functools import reduce
+from typing import List, Optional
+from cognee.shared.logging_utils import get_logger
+from cognee.tasks.codingagents.coding_rule_associations import get_existing_rules
+
+logger = get_logger("CodingRulesRetriever")
+
+
+class CodingRulesRetriever:
+    """Retriever for handling codeing rule based searches."""
+
+    def __init__(self, rules_nodeset_name: Optional[List[str]] = None):
+        if isinstance(rules_nodeset_name, list):
+            if not rules_nodeset_name:
+                # If there is no provided nodeset set to coding_agent_rules
+                rules_nodeset_name = ["coding_agent_rules"]
+
+        self.rules_nodeset_name = rules_nodeset_name
+        """Initialize retriever with search parameters."""
+
+    async def get_existing_rules(self, query_text):
+        if self.rules_nodeset_name:
+            rules_list = await asyncio.gather(
+                *[
+                    get_existing_rules(rules_nodeset_name=nodeset)
+                    for nodeset in self.rules_nodeset_name
+                ]
+            )
+
+            return reduce(lambda x, y: x + y, rules_list, [])
```
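`get_existing_rules` queries every configured node set concurrently and then flattens the per-node-set lists with `reduce`. The same gather-and-flatten idiom as a stand-alone snippet (the node-set names and rule strings are illustrative):

```python
import asyncio
from functools import reduce


async def rules_for(nodeset):
    # Stand-in for coding_rule_associations.get_existing_rules.
    await asyncio.sleep(0)
    return [f"{nodeset}: prefer explicit imports"]


async def main():
    nodesets = ["coding_agent_rules", "team_rules"]  # second name is made up
    rules_list = await asyncio.gather(*[rules_for(nodeset) for nodeset in nodesets])
    # reduce concatenates the per-nodeset lists into one flat list of rules.
    flat_rules = reduce(lambda x, y: x + y, rules_list, [])
    print(flat_rules)


asyncio.run(main())
```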
```diff
@@ -23,12 +23,14 @@ class CompletionRetriever(BaseRetriever):
         self,
         user_prompt_path: str = "context_for_question.txt",
         system_prompt_path: str = "answer_simple_question.txt",
+        system_prompt: Optional[str] = None,
         top_k: Optional[int] = 1,
     ):
         """Initialize retriever with optional custom prompt paths."""
         self.user_prompt_path = user_prompt_path
         self.system_prompt_path = system_prompt_path
         self.top_k = top_k if top_k is not None else 1
+        self.system_prompt = system_prompt
 
     async def get_context(self, query: str) -> str:
         """
@@ -65,7 +67,7 @@ class CompletionRetriever(BaseRetriever):
            logger.error("DocumentChunk_text collection not found")
            raise NoDataError("No data found in the system, please add data first.") from error
 
-    async def get_completion(self, query: str, context: Optional[Any] = None) ->
+    async def get_completion(self, query: str, context: Optional[Any] = None) -> str:
         """
         Generates an LLM completion using the context.
 
@@ -88,6 +90,10 @@ class CompletionRetriever(BaseRetriever):
            context = await self.get_context(query)
 
        completion = await generate_completion(
-            query,
+            query=query,
+            context=context,
+            user_prompt_path=self.user_prompt_path,
+            system_prompt_path=self.system_prompt_path,
+            system_prompt=self.system_prompt,
        )
-        return
+        return completion
```
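`CompletionRetriever` now accepts an inline `system_prompt` that is forwarded to `generate_completion` alongside the prompt-path defaults, presumably taking precedence over the file-based system prompt when set. A hedged usage sketch, assuming the constructor keeps the defaults shown above:

```python
from cognee.modules.retrieval.completion_retriever import CompletionRetriever

retriever = CompletionRetriever(
    system_prompt="Answer in one short sentence.",  # inline override, new in 0.3.0
    top_k=3,
)

# Inside a running event loop:
#     answer = await retriever.get_completion("What changed in the pipeline module?")
```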
```diff
@@ -1,6 +1,5 @@
 from typing import Any, Optional
 from cognee.infrastructure.databases.graph import get_graph_engine
-from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.utils.completion import generate_completion
 from cognee.modules.retrieval.exceptions import SearchTypeNotSupported, CypherSearchError
@@ -31,8 +30,7 @@ class CypherSearchRetriever(BaseRetriever):
         """
         Retrieves relevant context using a cypher query.
 
-        If
-        any error occurs during execution, logs the error and raises CypherSearchError.
+        If any error occurs during execution, logs the error and raises CypherSearchError.
 
         Parameters:
         -----------
@@ -46,12 +44,6 @@ class CypherSearchRetriever(BaseRetriever):
         """
         try:
            graph_engine = await get_graph_engine()
-
-            if isinstance(graph_engine, NetworkXAdapter):
-                raise SearchTypeNotSupported(
-                    "CYPHER search type not supported for NetworkXAdapter."
-                )
-
            result = await graph_engine.query(query)
        except Exception as e:
            logger.error("Failed to execture cypher search retrieval: %s", str(e))
```
```diff
@@ -1,4 +1,5 @@
-from typing import
+from typing import Optional, List, Type
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
 from cognee.modules.retrieval.utils.completion import generate_completion
@@ -26,9 +27,11 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
         self,
         user_prompt_path: str = "graph_context_for_question.txt",
         system_prompt_path: str = "answer_simple_question.txt",
+        system_prompt: Optional[str] = None,
         top_k: Optional[int] = 5,
         node_type: Optional[Type] = None,
         node_name: Optional[List[str]] = None,
+        save_interaction: bool = False,
     ):
         super().__init__(
            user_prompt_path=user_prompt_path,
@@ -36,11 +39,16 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
            top_k=top_k,
            node_type=node_type,
            node_name=node_name,
+            save_interaction=save_interaction,
+            system_prompt=system_prompt,
        )
 
     async def get_completion(
-        self,
-
+        self,
+        query: str,
+        context: Optional[List[Edge]] = None,
+        context_extension_rounds=4,
+    ) -> str:
         """
         Extends the context for a given query by retrieving related triplets and generating new
         completions based on them.
@@ -65,11 +73,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
         - List[str]: A list containing the generated answer based on the query and the
           extended context.
         """
-        triplets =
+        triplets = context
+
+        if triplets is None:
+            triplets = await self.get_context(query)
 
-
-        triplets += await self.get_triplets(query)
-        context = await self.resolve_edges_to_text(triplets)
+        context_text = await self.resolve_edges_to_text(triplets)
 
         round_idx = 1
 
@@ -81,14 +90,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
            )
            completion = await generate_completion(
                query=query,
-                context=
+                context=context_text,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
+                system_prompt=self.system_prompt,
            )
 
-            triplets += await self.
+            triplets += await self.get_context(completion)
            triplets = list(set(triplets))
-
+            context_text = await self.resolve_edges_to_text(triplets)
 
            num_triplets = len(triplets)
 
@@ -105,11 +115,17 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
 
            round_idx += 1
 
-
+        completion = await generate_completion(
            query=query,
-            context=
+            context=context_text,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
+            system_prompt=self.system_prompt,
        )
 
-
+        if self.save_interaction and context_text and triplets and completion:
+            await self.save_qa(
+                question=query, answer=completion, context=context_text, triplets=triplets
+            )
+
+        return completion
```
```diff
@@ -1,4 +1,5 @@
-from typing import
+from typing import Optional, List, Type
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
 from cognee.shared.logging_utils import get_logger
 
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@@ -32,16 +33,20 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
         validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
         followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
         followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
+        system_prompt: Optional[str] = None,
         top_k: Optional[int] = 5,
         node_type: Optional[Type] = None,
         node_name: Optional[List[str]] = None,
+        save_interaction: bool = False,
     ):
         super().__init__(
            user_prompt_path=user_prompt_path,
            system_prompt_path=system_prompt_path,
+            system_prompt=system_prompt,
            top_k=top_k,
            node_type=node_type,
            node_name=node_name,
+            save_interaction=save_interaction,
        )
        self.validation_system_prompt_path = validation_system_prompt_path
        self.validation_user_prompt_path = validation_user_prompt_path
@@ -49,8 +54,11 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
         self.followup_user_prompt_path = followup_user_prompt_path
 
     async def get_completion(
-        self,
-
+        self,
+        query: str,
+        context: Optional[List[Edge]] = None,
+        max_iter=4,
+    ) -> str:
         """
         Generate completion responses based on a user query and contextual information.
 
@@ -75,25 +83,29 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
         """
         followup_question = ""
         triplets = []
-
+        completion = ""
 
         for round_idx in range(max_iter + 1):
            if round_idx == 0:
                if context is None:
-
+                    triplets = await self.get_context(query)
+                    context_text = await self.resolve_edges_to_text(triplets)
+                else:
+                    context_text = await self.resolve_edges_to_text(context)
            else:
-                triplets += await self.
-
+                triplets += await self.get_context(followup_question)
+                context_text = await self.resolve_edges_to_text(list(set(triplets)))
 
-
+            completion = await generate_completion(
                query=query,
-                context=
+                context=context_text,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
+                system_prompt=self.system_prompt,
            )
-            logger.info(f"Chain-of-thought: round {round_idx} - answer: {
+            logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
            if round_idx < max_iter:
-                valid_args = {"query": query, "answer":
+                valid_args = {"query": query, "answer": completion, "context": context_text}
                valid_user_prompt = LLMGateway.render_prompt(
                    filename=self.validation_user_prompt_path, context=valid_args
                )
@@ -106,7 +118,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
                    system_prompt=valid_system_prompt,
                    response_model=str,
                )
-                followup_args = {"query": query, "answer":
+                followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
                followup_prompt = LLMGateway.render_prompt(
                    filename=self.followup_user_prompt_path, context=followup_args
                )
@@ -121,4 +133,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
                    f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
                )
 
-
+        if self.save_interaction and context and triplets and completion:
+            await self.save_qa(
+                question=query, answer=completion, context=context_text, triplets=triplets
+            )
+
+        return completion
```
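Both graph completion retrievers gain the same additions: an optional inline `system_prompt`, a `save_interaction` flag that persists the question/answer pair via `save_qa`, and a `get_completion` signature that accepts precomputed `Edge` context. A hedged usage sketch (the query text is illustrative, and the calls must run inside an event loop):

```python
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever

retriever = GraphCompletionCotRetriever(
    system_prompt="Answer strictly from the provided triplets.",  # inline override, new in 0.3.0
    save_interaction=True,  # persist question/answer/context via save_qa after the run
)

# Inside a running event loop:
#     triplets = await retriever.get_context("Who calls run_pipeline_per_dataset?")
#     answer = await retriever.get_completion(
#         "Who calls run_pipeline_per_dataset?", context=triplets
#     )
```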