cognee 0.2.3.dev1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/__main__.py +4 -0
- cognee/api/client.py +28 -3
- cognee/api/health.py +10 -13
- cognee/api/v1/add/add.py +20 -6
- cognee/api/v1/add/routers/get_add_router.py +12 -37
- cognee/api/v1/cloud/routers/__init__.py +1 -0
- cognee/api/v1/cloud/routers/get_checks_router.py +23 -0
- cognee/api/v1/cognify/code_graph_pipeline.py +14 -3
- cognee/api/v1/cognify/cognify.py +67 -105
- cognee/api/v1/cognify/routers/get_cognify_router.py +11 -3
- cognee/api/v1/datasets/routers/get_datasets_router.py +16 -5
- cognee/api/v1/memify/routers/__init__.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +100 -0
- cognee/api/v1/notebooks/routers/__init__.py +1 -0
- cognee/api/v1/notebooks/routers/get_notebooks_router.py +96 -0
- cognee/api/v1/responses/default_tools.py +4 -0
- cognee/api/v1/responses/dispatch_function.py +6 -1
- cognee/api/v1/responses/models.py +1 -1
- cognee/api/v1/search/routers/get_search_router.py +20 -1
- cognee/api/v1/search/search.py +17 -4
- cognee/api/v1/sync/__init__.py +17 -0
- cognee/api/v1/sync/routers/__init__.py +3 -0
- cognee/api/v1/sync/routers/get_sync_router.py +241 -0
- cognee/api/v1/sync/sync.py +877 -0
- cognee/api/v1/ui/__init__.py +1 -0
- cognee/api/v1/ui/ui.py +529 -0
- cognee/api/v1/users/routers/get_auth_router.py +13 -1
- cognee/base_config.py +10 -1
- cognee/cli/__init__.py +10 -0
- cognee/cli/_cognee.py +273 -0
- cognee/cli/commands/__init__.py +1 -0
- cognee/cli/commands/add_command.py +80 -0
- cognee/cli/commands/cognify_command.py +128 -0
- cognee/cli/commands/config_command.py +225 -0
- cognee/cli/commands/delete_command.py +80 -0
- cognee/cli/commands/search_command.py +149 -0
- cognee/cli/config.py +33 -0
- cognee/cli/debug.py +21 -0
- cognee/cli/echo.py +45 -0
- cognee/cli/exceptions.py +23 -0
- cognee/cli/minimal_cli.py +97 -0
- cognee/cli/reference.py +26 -0
- cognee/cli/suppress_logging.py +12 -0
- cognee/eval_framework/corpus_builder/corpus_builder_executor.py +2 -2
- cognee/eval_framework/eval_config.py +1 -1
- cognee/infrastructure/databases/graph/config.py +10 -4
- cognee/infrastructure/databases/graph/get_graph_engine.py +4 -9
- cognee/infrastructure/databases/graph/kuzu/adapter.py +199 -2
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +138 -0
- cognee/infrastructure/databases/relational/__init__.py +2 -0
- cognee/infrastructure/databases/relational/get_async_session.py +15 -0
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +6 -1
- cognee/infrastructure/databases/relational/with_async_session.py +25 -0
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +1 -1
- cognee/infrastructure/databases/vector/config.py +13 -6
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +6 -4
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +16 -7
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +5 -5
- cognee/infrastructure/databases/vector/embeddings/config.py +2 -2
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +2 -6
- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +10 -7
- cognee/infrastructure/files/storage/LocalFileStorage.py +9 -0
- cognee/infrastructure/files/storage/S3FileStorage.py +5 -0
- cognee/infrastructure/files/storage/StorageManager.py +7 -1
- cognee/infrastructure/files/storage/storage.py +16 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +14 -9
- cognee/infrastructure/files/utils/get_file_metadata.py +2 -1
- cognee/infrastructure/llm/LLMGateway.py +32 -5
- cognee/infrastructure/llm/config.py +6 -4
- cognee/infrastructure/llm/prompts/extract_query_time.txt +15 -0
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +25 -0
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +30 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py +16 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +2 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_event_entities.py +44 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/__init__.py +1 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_content_graph.py +19 -15
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_event_graph.py +46 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +3 -3
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +3 -3
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +14 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +6 -4
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +28 -4
- cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +2 -2
- cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py +3 -3
- cognee/infrastructure/llm/tokenizer/Mistral/adapter.py +3 -3
- cognee/infrastructure/llm/tokenizer/TikToken/adapter.py +6 -6
- cognee/infrastructure/llm/utils.py +7 -7
- cognee/infrastructure/utils/run_sync.py +8 -1
- cognee/modules/chunking/models/DocumentChunk.py +4 -3
- cognee/modules/cloud/exceptions/CloudApiKeyMissingError.py +15 -0
- cognee/modules/cloud/exceptions/CloudConnectionError.py +15 -0
- cognee/modules/cloud/exceptions/__init__.py +2 -0
- cognee/modules/cloud/operations/__init__.py +1 -0
- cognee/modules/cloud/operations/check_api_key.py +25 -0
- cognee/modules/data/deletion/prune_system.py +1 -1
- cognee/modules/data/methods/__init__.py +2 -0
- cognee/modules/data/methods/check_dataset_name.py +1 -1
- cognee/modules/data/methods/create_authorized_dataset.py +19 -0
- cognee/modules/data/methods/get_authorized_dataset.py +11 -5
- cognee/modules/data/methods/get_authorized_dataset_by_name.py +16 -0
- cognee/modules/data/methods/get_dataset_data.py +1 -1
- cognee/modules/data/methods/load_or_create_datasets.py +2 -20
- cognee/modules/engine/models/Event.py +16 -0
- cognee/modules/engine/models/Interval.py +8 -0
- cognee/modules/engine/models/Timestamp.py +13 -0
- cognee/modules/engine/models/__init__.py +3 -0
- cognee/modules/engine/utils/__init__.py +2 -0
- cognee/modules/engine/utils/generate_event_datapoint.py +46 -0
- cognee/modules/engine/utils/generate_timestamp_datapoint.py +51 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +2 -2
- cognee/modules/graph/methods/get_formatted_graph_data.py +3 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/resolve_edges_to_text.py +71 -0
- cognee/modules/memify/__init__.py +1 -0
- cognee/modules/memify/memify.py +118 -0
- cognee/modules/notebooks/methods/__init__.py +5 -0
- cognee/modules/notebooks/methods/create_notebook.py +26 -0
- cognee/modules/notebooks/methods/delete_notebook.py +13 -0
- cognee/modules/notebooks/methods/get_notebook.py +21 -0
- cognee/modules/notebooks/methods/get_notebooks.py +18 -0
- cognee/modules/notebooks/methods/update_notebook.py +17 -0
- cognee/modules/notebooks/models/Notebook.py +53 -0
- cognee/modules/notebooks/models/__init__.py +1 -0
- cognee/modules/notebooks/operations/__init__.py +1 -0
- cognee/modules/notebooks/operations/run_in_local_sandbox.py +55 -0
- cognee/modules/pipelines/__init__.py +1 -1
- cognee/modules/pipelines/exceptions/tasks.py +18 -0
- cognee/modules/pipelines/layers/__init__.py +1 -0
- cognee/modules/pipelines/layers/check_pipeline_run_qualification.py +59 -0
- cognee/modules/pipelines/layers/pipeline_execution_mode.py +127 -0
- cognee/modules/pipelines/layers/reset_dataset_pipeline_run_status.py +28 -0
- cognee/modules/pipelines/layers/resolve_authorized_user_dataset.py +34 -0
- cognee/modules/pipelines/layers/resolve_authorized_user_datasets.py +55 -0
- cognee/modules/pipelines/layers/setup_and_check_environment.py +41 -0
- cognee/modules/pipelines/layers/validate_pipeline_tasks.py +20 -0
- cognee/modules/pipelines/methods/__init__.py +2 -0
- cognee/modules/pipelines/methods/get_pipeline_runs_by_dataset.py +34 -0
- cognee/modules/pipelines/methods/reset_pipeline_run_status.py +16 -0
- cognee/modules/pipelines/operations/__init__.py +0 -1
- cognee/modules/pipelines/operations/log_pipeline_run_initiated.py +1 -1
- cognee/modules/pipelines/operations/pipeline.py +24 -138
- cognee/modules/pipelines/operations/run_tasks.py +17 -41
- cognee/modules/retrieval/base_feedback.py +11 -0
- cognee/modules/retrieval/base_graph_retriever.py +18 -0
- cognee/modules/retrieval/base_retriever.py +1 -1
- cognee/modules/retrieval/code_retriever.py +8 -0
- cognee/modules/retrieval/coding_rules_retriever.py +31 -0
- cognee/modules/retrieval/completion_retriever.py +9 -3
- cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +1 -0
- cognee/modules/retrieval/cypher_search_retriever.py +1 -9
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +29 -13
- cognee/modules/retrieval/graph_completion_cot_retriever.py +30 -13
- cognee/modules/retrieval/graph_completion_retriever.py +107 -56
- cognee/modules/retrieval/graph_summary_completion_retriever.py +5 -1
- cognee/modules/retrieval/insights_retriever.py +14 -3
- cognee/modules/retrieval/natural_language_retriever.py +0 -4
- cognee/modules/retrieval/summaries_retriever.py +1 -1
- cognee/modules/retrieval/temporal_retriever.py +152 -0
- cognee/modules/retrieval/user_qa_feedback.py +83 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +7 -32
- cognee/modules/retrieval/utils/completion.py +10 -3
- cognee/modules/retrieval/utils/extract_uuid_from_node.py +18 -0
- cognee/modules/retrieval/utils/models.py +40 -0
- cognee/modules/search/methods/get_search_type_tools.py +168 -0
- cognee/modules/search/methods/no_access_control_search.py +47 -0
- cognee/modules/search/methods/search.py +239 -118
- cognee/modules/search/types/SearchResult.py +21 -0
- cognee/modules/search/types/SearchType.py +3 -0
- cognee/modules/search/types/__init__.py +1 -0
- cognee/modules/search/utils/__init__.py +2 -0
- cognee/modules/search/utils/prepare_search_result.py +41 -0
- cognee/modules/search/utils/transform_context_to_graph.py +38 -0
- cognee/modules/settings/get_settings.py +2 -2
- cognee/modules/sync/__init__.py +1 -0
- cognee/modules/sync/methods/__init__.py +23 -0
- cognee/modules/sync/methods/create_sync_operation.py +53 -0
- cognee/modules/sync/methods/get_sync_operation.py +107 -0
- cognee/modules/sync/methods/update_sync_operation.py +248 -0
- cognee/modules/sync/models/SyncOperation.py +142 -0
- cognee/modules/sync/models/__init__.py +3 -0
- cognee/modules/users/__init__.py +0 -1
- cognee/modules/users/methods/__init__.py +4 -1
- cognee/modules/users/methods/create_user.py +26 -1
- cognee/modules/users/methods/get_authenticated_user.py +36 -42
- cognee/modules/users/methods/get_default_user.py +3 -1
- cognee/modules/users/permissions/methods/get_specific_user_permission_datasets.py +2 -1
- cognee/root_dir.py +19 -0
- cognee/shared/CodeGraphEntities.py +1 -0
- cognee/shared/logging_utils.py +143 -32
- cognee/shared/utils.py +0 -1
- cognee/tasks/codingagents/coding_rule_associations.py +127 -0
- cognee/tasks/graph/extract_graph_from_data.py +6 -2
- cognee/tasks/ingestion/save_data_item_to_storage.py +23 -0
- cognee/tasks/memify/__init__.py +2 -0
- cognee/tasks/memify/extract_subgraph.py +7 -0
- cognee/tasks/memify/extract_subgraph_chunks.py +11 -0
- cognee/tasks/repo_processor/get_local_dependencies.py +2 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +144 -47
- cognee/tasks/storage/add_data_points.py +33 -3
- cognee/tasks/temporal_graph/__init__.py +1 -0
- cognee/tasks/temporal_graph/add_entities_to_event.py +85 -0
- cognee/tasks/temporal_graph/enrich_events.py +34 -0
- cognee/tasks/temporal_graph/extract_events_and_entities.py +32 -0
- cognee/tasks/temporal_graph/extract_knowledge_graph_from_events.py +41 -0
- cognee/tasks/temporal_graph/models.py +49 -0
- cognee/tests/integration/cli/__init__.py +3 -0
- cognee/tests/integration/cli/test_cli_integration.py +331 -0
- cognee/tests/integration/documents/PdfDocument_test.py +2 -2
- cognee/tests/integration/documents/TextDocument_test.py +2 -4
- cognee/tests/integration/documents/UnstructuredDocument_test.py +5 -8
- cognee/tests/{test_deletion.py → test_delete_hard.py} +0 -37
- cognee/tests/test_delete_soft.py +85 -0
- cognee/tests/test_kuzu.py +2 -2
- cognee/tests/test_neo4j.py +2 -2
- cognee/tests/test_permissions.py +3 -3
- cognee/tests/test_relational_db_migration.py +7 -5
- cognee/tests/test_search_db.py +136 -23
- cognee/tests/test_temporal_graph.py +167 -0
- cognee/tests/unit/api/__init__.py +1 -0
- cognee/tests/unit/api/test_conditional_authentication_endpoints.py +246 -0
- cognee/tests/unit/cli/__init__.py +3 -0
- cognee/tests/unit/cli/test_cli_commands.py +483 -0
- cognee/tests/unit/cli/test_cli_edge_cases.py +625 -0
- cognee/tests/unit/cli/test_cli_main.py +173 -0
- cognee/tests/unit/cli/test_cli_runner.py +62 -0
- cognee/tests/unit/cli/test_cli_utils.py +127 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +12 -15
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +10 -15
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +4 -3
- cognee/tests/unit/modules/retrieval/insights_retriever_test.py +4 -2
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +225 -0
- cognee/tests/unit/modules/users/__init__.py +1 -0
- cognee/tests/unit/modules/users/test_conditional_authentication.py +277 -0
- cognee/tests/unit/processing/utils/utils_test.py +20 -1
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/METADATA +13 -9
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/RECORD +247 -135
- cognee-0.3.0.dist-info/entry_points.txt +2 -0
- cognee/infrastructure/databases/graph/networkx/adapter.py +0 -1017
- cognee/infrastructure/pipeline/models/Operation.py +0 -60
- cognee/notebooks/github_analysis_step_by_step.ipynb +0 -37
- cognee/tests/tasks/descriptive_metrics/networkx_metrics_test.py +0 -7
- cognee/tests/unit/modules/search/search_methods_test.py +0 -223
- /cognee/{infrastructure/databases/graph/networkx → api/v1/memify}/__init__.py +0 -0
- /cognee/{infrastructure/pipeline/models → tasks/codingagents}/__init__.py +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/WHEEL +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/licenses/NOTICE.md +0 -0
cognee/modules/sync/methods/get_sync_operation.py ADDED
@@ -0,0 +1,107 @@
+from uuid import UUID
+from typing import List, Optional
+from sqlalchemy import select, desc, and_
+from cognee.modules.sync.models import SyncOperation, SyncStatus
+from cognee.infrastructure.databases.relational import get_relational_engine
+
+
+async def get_sync_operation(run_id: str) -> Optional[SyncOperation]:
+    """
+    Get a sync operation by its run_id.
+
+    Args:
+        run_id: The public run_id of the sync operation
+
+    Returns:
+        SyncOperation: The sync operation record, or None if not found
+    """
+    db_engine = get_relational_engine()
+
+    async with db_engine.get_async_session() as session:
+        query = select(SyncOperation).where(SyncOperation.run_id == run_id)
+        result = await session.execute(query)
+        return result.scalars().first()
+
+
+async def get_user_sync_operations(
+    user_id: UUID, limit: int = 50, offset: int = 0
+) -> List[SyncOperation]:
+    """
+    Get sync operations for a specific user, ordered by most recent first.
+
+    Args:
+        user_id: UUID of the user
+        limit: Maximum number of records to return
+        offset: Number of records to skip
+
+    Returns:
+        List[SyncOperation]: List of sync operations for the user
+    """
+    db_engine = get_relational_engine()
+
+    async with db_engine.get_async_session() as session:
+        query = (
+            select(SyncOperation)
+            .where(SyncOperation.user_id == user_id)
+            .order_by(desc(SyncOperation.created_at))
+            .limit(limit)
+            .offset(offset)
+        )
+        result = await session.execute(query)
+        return list(result.scalars().all())
+
+
+async def get_sync_operations_by_dataset(
+    dataset_id: UUID, limit: int = 50, offset: int = 0
+) -> List[SyncOperation]:
+    """
+    Get sync operations for a specific dataset.
+
+    Args:
+        dataset_id: UUID of the dataset
+        limit: Maximum number of records to return
+        offset: Number of records to skip
+
+    Returns:
+        List[SyncOperation]: List of sync operations for the dataset
+    """
+    db_engine = get_relational_engine()
+
+    async with db_engine.get_async_session() as session:
+        query = (
+            select(SyncOperation)
+            .where(SyncOperation.dataset_id == dataset_id)
+            .order_by(desc(SyncOperation.created_at))
+            .limit(limit)
+            .offset(offset)
+        )
+        result = await session.execute(query)
+        return list(result.scalars().all())
+
+
+async def get_running_sync_operations_for_user(user_id: UUID) -> List[SyncOperation]:
+    """
+    Get all currently running sync operations for a specific user.
+    Checks for operations with STARTED or IN_PROGRESS status.
+
+    Args:
+        user_id: UUID of the user
+
+    Returns:
+        List[SyncOperation]: List of running sync operations for the user
+    """
+    db_engine = get_relational_engine()
+
+    async with db_engine.get_async_session() as session:
+        query = (
+            select(SyncOperation)
+            .where(
+                and_(
+                    SyncOperation.user_id == user_id,
+                    SyncOperation.status.in_([SyncStatus.STARTED, SyncStatus.IN_PROGRESS]),
+                )
+            )
+            .order_by(desc(SyncOperation.created_at))
+        )
+        result = await session.execute(query)
+        return list(result.scalars().all())
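These helpers are plain read-only queries, so callers can poll them to track a background sync. A minimal sketch of such a poll loop, assuming the new cognee.modules.sync.methods package re-exports these functions (its new 23-line __init__.py suggests it does) and that run_id comes from whatever kicked off the sync:

import asyncio

from cognee.modules.sync.methods import get_sync_operation
from cognee.modules.sync.models import SyncStatus

# Statuses after which the record will no longer change.
TERMINAL_STATUSES = (SyncStatus.COMPLETED, SyncStatus.FAILED, SyncStatus.CANCELLED)


async def wait_for_sync(run_id: str, poll_seconds: float = 2.0):
    # Poll the relational store until the operation reaches a terminal status.
    while True:
        operation = await get_sync_operation(run_id)
        if operation is None:
            raise ValueError(f"No sync operation found for run_id {run_id}")
        if operation.status in TERMINAL_STATUSES:
            return operation
        await asyncio.sleep(poll_seconds)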
cognee/modules/sync/methods/update_sync_operation.py ADDED
@@ -0,0 +1,248 @@
+import asyncio
+from typing import Optional, List
+from datetime import datetime, timezone
+from sqlalchemy import select
+from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError, TimeoutError
+from cognee.modules.sync.models import SyncOperation, SyncStatus
+from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
+
+logger = get_logger("sync.db_operations")
+
+
+async def _retry_db_operation(operation_func, run_id: str, max_retries: int = 3):
+    """
+    Retry database operations with exponential backoff for transient failures.
+
+    Args:
+        operation_func: Async function to retry
+        run_id: Run ID for logging context
+        max_retries: Maximum number of retry attempts
+
+    Returns:
+        Result of the operation function
+
+    Raises:
+        Exception: Re-raises the last exception if all retries fail
+    """
+    attempt = 0
+    last_exception = None
+
+    while attempt < max_retries:
+        try:
+            return await operation_func()
+        except (DisconnectionError, OperationalError, TimeoutError) as e:
+            attempt += 1
+            last_exception = e
+
+            if attempt >= max_retries:
+                logger.error(
+                    f"Database operation failed after {max_retries} attempts for run_id {run_id}: {str(e)}"
+                )
+                break
+
+            backoff_time = calculate_backoff(attempt - 1)  # calculate_backoff is 0-indexed
+            logger.warning(
+                f"Database operation failed for run_id {run_id}, retrying in {backoff_time:.2f}s (attempt {attempt}/{max_retries}): {str(e)}"
+            )
+            await asyncio.sleep(backoff_time)
+
+        except Exception as e:
+            # Non-transient errors should not be retried
+            logger.error(f"Non-retryable database error for run_id {run_id}: {str(e)}")
+            raise
+
+    # If we get here, all retries failed
+    raise last_exception
+
+
+async def update_sync_operation(
+    run_id: str,
+    status: Optional[SyncStatus] = None,
+    progress_percentage: Optional[int] = None,
+    records_downloaded: Optional[int] = None,
+    total_records_to_sync: Optional[int] = None,
+    total_records_to_download: Optional[int] = None,
+    total_records_to_upload: Optional[int] = None,
+    records_uploaded: Optional[int] = None,
+    bytes_downloaded: Optional[int] = None,
+    bytes_uploaded: Optional[int] = None,
+    dataset_sync_hashes: Optional[dict] = None,
+    error_message: Optional[str] = None,
+    retry_count: Optional[int] = None,
+    started_at: Optional[datetime] = None,
+    completed_at: Optional[datetime] = None,
+) -> Optional[SyncOperation]:
+    """
+    Update a sync operation record with new status/progress information.
+
+    Args:
+        run_id: The public run_id of the sync operation to update
+        status: New status for the operation
+        progress_percentage: Progress percentage (0-100)
+        records_downloaded: Number of records downloaded so far
+        total_records_to_sync: Total number of records that need to be synced
+        total_records_to_download: Total number of records to download from cloud
+        total_records_to_upload: Total number of records to upload to cloud
+        records_uploaded: Number of records uploaded so far
+        bytes_downloaded: Total bytes downloaded from cloud
+        bytes_uploaded: Total bytes uploaded to cloud
+        dataset_sync_hashes: Dict mapping dataset_id -> {uploaded: [hashes], downloaded: [hashes]}
+        error_message: Error message if operation failed
+        retry_count: Number of retry attempts
+        started_at: When the actual processing started
+        completed_at: When the operation completed (success or failure)
+
+    Returns:
+        SyncOperation: The updated sync operation record, or None if not found
+    """
+
+    async def _perform_update():
+        db_engine = get_relational_engine()
+
+        async with db_engine.get_async_session() as session:
+            try:
+                # Find the sync operation
+                query = select(SyncOperation).where(SyncOperation.run_id == run_id)
+                result = await session.execute(query)
+                sync_operation = result.scalars().first()
+
+                if not sync_operation:
+                    logger.warning(f"Sync operation not found for run_id: {run_id}")
+                    return None
+
+                # Log what we're updating for debugging
+                updates = []
+                if status is not None:
+                    updates.append(f"status={status.value}")
+                if progress_percentage is not None:
+                    updates.append(f"progress={progress_percentage}%")
+                if records_downloaded is not None:
+                    updates.append(f"downloaded={records_downloaded}")
+                if records_uploaded is not None:
+                    updates.append(f"uploaded={records_uploaded}")
+                if total_records_to_sync is not None:
+                    updates.append(f"total_sync={total_records_to_sync}")
+                if total_records_to_download is not None:
+                    updates.append(f"total_download={total_records_to_download}")
+                if total_records_to_upload is not None:
+                    updates.append(f"total_upload={total_records_to_upload}")
+
+                if updates:
+                    logger.debug(f"Updating sync operation {run_id}: {', '.join(updates)}")
+
+                # Update fields that were provided
+                if status is not None:
+                    sync_operation.status = status
+
+                if progress_percentage is not None:
+                    sync_operation.progress_percentage = max(0, min(100, progress_percentage))
+
+                if records_downloaded is not None:
+                    sync_operation.records_downloaded = records_downloaded
+
+                if records_uploaded is not None:
+                    sync_operation.records_uploaded = records_uploaded
+
+                if total_records_to_sync is not None:
+                    sync_operation.total_records_to_sync = total_records_to_sync
+
+                if total_records_to_download is not None:
+                    sync_operation.total_records_to_download = total_records_to_download
+
+                if total_records_to_upload is not None:
+                    sync_operation.total_records_to_upload = total_records_to_upload
+
+                if bytes_downloaded is not None:
+                    sync_operation.bytes_downloaded = bytes_downloaded
+
+                if bytes_uploaded is not None:
+                    sync_operation.bytes_uploaded = bytes_uploaded
+
+                if dataset_sync_hashes is not None:
+                    sync_operation.dataset_sync_hashes = dataset_sync_hashes
+
+                if error_message is not None:
+                    sync_operation.error_message = error_message
+
+                if retry_count is not None:
+                    sync_operation.retry_count = retry_count
+
+                if started_at is not None:
+                    sync_operation.started_at = started_at
+
+                if completed_at is not None:
+                    sync_operation.completed_at = completed_at
+
+                # Auto-set completion timestamp for terminal statuses
+                if (
+                    status in [SyncStatus.COMPLETED, SyncStatus.FAILED, SyncStatus.CANCELLED]
+                    and completed_at is None
+                ):
+                    sync_operation.completed_at = datetime.now(timezone.utc)
+
+                # Auto-set started timestamp when moving to IN_PROGRESS
+                if status == SyncStatus.IN_PROGRESS and sync_operation.started_at is None:
+                    sync_operation.started_at = datetime.now(timezone.utc)
+
+                await session.commit()
+                await session.refresh(sync_operation)
+
+                logger.debug(f"Successfully updated sync operation {run_id}")
+                return sync_operation
+
+            except SQLAlchemyError as e:
+                logger.error(
+                    f"Database error updating sync operation {run_id}: {str(e)}", exc_info=True
+                )
+                await session.rollback()
+                raise
+            except Exception as e:
+                logger.error(
+                    f"Unexpected error updating sync operation {run_id}: {str(e)}", exc_info=True
+                )
+                await session.rollback()
+                raise
+
+    # Use retry logic for the database operation
+    return await _retry_db_operation(_perform_update, run_id)
+
+
+async def mark_sync_started(run_id: str) -> Optional[SyncOperation]:
+    """Convenience method to mark a sync operation as started."""
+    return await update_sync_operation(
+        run_id=run_id, status=SyncStatus.IN_PROGRESS, started_at=datetime.now(timezone.utc)
+    )
+
+
+async def mark_sync_completed(
+    run_id: str,
+    records_downloaded: int = 0,
+    records_uploaded: int = 0,
+    bytes_downloaded: int = 0,
+    bytes_uploaded: int = 0,
+    dataset_sync_hashes: Optional[dict] = None,
+) -> Optional[SyncOperation]:
+    """Convenience method to mark a sync operation as completed successfully."""
+    return await update_sync_operation(
+        run_id=run_id,
+        status=SyncStatus.COMPLETED,
+        progress_percentage=100,
+        records_downloaded=records_downloaded,
+        records_uploaded=records_uploaded,
+        bytes_downloaded=bytes_downloaded,
+        bytes_uploaded=bytes_uploaded,
+        dataset_sync_hashes=dataset_sync_hashes,
+        completed_at=datetime.now(timezone.utc),
+    )
+
+
+async def mark_sync_failed(run_id: str, error_message: str) -> Optional[SyncOperation]:
+    """Convenience method to mark a sync operation as failed."""
+    return await update_sync_operation(
+        run_id=run_id,
+        status=SyncStatus.FAILED,
+        error_message=error_message,
+        completed_at=datetime.now(timezone.utc),
+    )
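_retry_db_operation retries only the SQLAlchemy error types it treats as transient (disconnects, operational errors, timeouts); anything else propagates immediately. The mark_* helpers are thin wrappers over update_sync_operation, so a sync worker would typically drive the whole lifecycle through them. A hypothetical driver sketch (the real worker is the new cognee/api/v1/sync/sync.py, not reproduced here):

from cognee.modules.sync.methods import mark_sync_started, mark_sync_completed, mark_sync_failed


async def run_sync_job(run_id: str):
    # Flips the record to IN_PROGRESS and stamps started_at.
    await mark_sync_started(run_id)
    try:
        # ... do the actual upload/download work, reporting progress along
        # the way via update_sync_operation(run_id, progress_percentage=...) ...
        await mark_sync_completed(run_id, records_uploaded=10, bytes_uploaded=4096)
    except Exception as error:
        # Stores the error message and stamps completed_at on the record.
        await mark_sync_failed(run_id, error_message=str(error))
        raise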
cognee/modules/sync/models/SyncOperation.py ADDED
@@ -0,0 +1,142 @@
+from uuid import uuid4
+from enum import Enum
+from typing import Optional, List
+from datetime import datetime, timezone
+from sqlalchemy import (
+    Column,
+    Text,
+    DateTime,
+    UUID as SQLAlchemy_UUID,
+    Integer,
+    Enum as SQLEnum,
+    JSON,
+)
+
+from cognee.infrastructure.databases.relational import Base
+
+
+class SyncStatus(str, Enum):
+    """Enumeration of possible sync operation statuses."""
+
+    STARTED = "started"
+    IN_PROGRESS = "in_progress"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    CANCELLED = "cancelled"
+
+
+class SyncOperation(Base):
+    """
+    Database model for tracking sync operations.
+
+    This model stores information about background sync operations,
+    allowing users to monitor progress and query the status of their sync requests.
+    """
+
+    __tablename__ = "sync_operations"
+
+    # Primary identifiers
+    id = Column(SQLAlchemy_UUID, primary_key=True, default=uuid4, doc="Database primary key")
+    run_id = Column(Text, unique=True, index=True, doc="Public run ID returned to users")
+
+    # Status and progress tracking
+    status = Column(
+        SQLEnum(SyncStatus), default=SyncStatus.STARTED, doc="Current status of the sync operation"
+    )
+    progress_percentage = Column(Integer, default=0, doc="Progress percentage (0-100)")
+
+    # Operation metadata
+    dataset_ids = Column(JSON, doc="Array of dataset IDs being synced")
+    dataset_names = Column(JSON, doc="Array of dataset names being synced")
+    user_id = Column(SQLAlchemy_UUID, index=True, doc="ID of the user who initiated the sync")
+
+    # Timing information
+    created_at = Column(
+        DateTime(timezone=True),
+        default=lambda: datetime.now(timezone.utc),
+        doc="When the sync was initiated",
+    )
+    started_at = Column(DateTime(timezone=True), doc="When the actual sync processing began")
+    completed_at = Column(
+        DateTime(timezone=True), doc="When the sync finished (success or failure)"
+    )
+
+    # Operation details
+    total_records_to_sync = Column(Integer, doc="Total number of records to sync")
+    total_records_to_download = Column(Integer, doc="Total number of records to download")
+    total_records_to_upload = Column(Integer, doc="Total number of records to upload")
+
+    records_downloaded = Column(Integer, default=0, doc="Number of records successfully downloaded")
+    records_uploaded = Column(Integer, default=0, doc="Number of records successfully uploaded")
+    bytes_downloaded = Column(Integer, default=0, doc="Total bytes downloaded from cloud")
+    bytes_uploaded = Column(Integer, default=0, doc="Total bytes uploaded to cloud")
+
+    # Data lineage tracking per dataset
+    dataset_sync_hashes = Column(
+        JSON, doc="Mapping of dataset_id -> {uploaded: [hashes], downloaded: [hashes]}"
+    )
+
+    # Error handling
+    error_message = Column(Text, doc="Error message if sync failed")
+    retry_count = Column(Integer, default=0, doc="Number of retry attempts")
+
+    def get_duration_seconds(self) -> Optional[float]:
+        """Get the duration of the sync operation in seconds."""
+        if not self.created_at:
+            return None
+
+        end_time = self.completed_at or datetime.now(timezone.utc)
+        return (end_time - self.created_at).total_seconds()
+
+    def get_progress_info(self) -> dict:
+        """Get comprehensive progress information."""
+        total_records_processed = (self.records_downloaded or 0) + (self.records_uploaded or 0)
+        total_bytes_transferred = (self.bytes_downloaded or 0) + (self.bytes_uploaded or 0)
+
+        return {
+            "status": self.status.value,
+            "progress_percentage": self.progress_percentage,
+            "records_processed": f"{total_records_processed}/{self.total_records_to_sync or 'unknown'}",
+            "records_downloaded": self.records_downloaded or 0,
+            "records_uploaded": self.records_uploaded or 0,
+            "bytes_transferred": total_bytes_transferred,
+            "bytes_downloaded": self.bytes_downloaded or 0,
+            "bytes_uploaded": self.bytes_uploaded or 0,
+            "duration_seconds": self.get_duration_seconds(),
+            "error_message": self.error_message,
+            "dataset_sync_hashes": self.dataset_sync_hashes or {},
+        }
+
+    def _get_all_sync_hashes(self) -> List[str]:
+        """Get all content hashes for data created/modified during this sync operation."""
+        all_hashes = set()
+        dataset_hashes = self.dataset_sync_hashes or {}
+
+        for dataset_id, operations in dataset_hashes.items():
+            if isinstance(operations, dict):
+                all_hashes.update(operations.get("uploaded", []))
+                all_hashes.update(operations.get("downloaded", []))
+
+        return list(all_hashes)
+
+    def _get_dataset_sync_hashes(self, dataset_id: str) -> dict:
+        """Get uploaded/downloaded hashes for a specific dataset."""
+        dataset_hashes = self.dataset_sync_hashes or {}
+        return dataset_hashes.get(dataset_id, {"uploaded": [], "downloaded": []})
+
+    def was_data_synced(self, content_hash: str, dataset_id: Optional[str] = None) -> bool:
+        """
+        Check if a specific piece of data was part of this sync operation.
+
+        Args:
+            content_hash: The content hash to check for
+            dataset_id: Optional - check only within this dataset
+        """
+        if dataset_id:
+            dataset_hashes = self._get_dataset_sync_hashes(dataset_id)
+            return content_hash in dataset_hashes.get(
+                "uploaded", []
+            ) or content_hash in dataset_hashes.get("downloaded", [])
+
+        all_hashes = self._get_all_sync_hashes()
+        return content_hash in all_hashes
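get_progress_info() packages everything a status endpoint needs, and was_data_synced() answers lineage questions against the dataset_sync_hashes JSON column. A small inspection sketch (the run id and hash values are made up; get_sync_operation is the query helper shown earlier):

from cognee.modules.sync.methods import get_sync_operation


async def show_progress(run_id: str) -> None:
    operation = await get_sync_operation(run_id)
    info = operation.get_progress_info()
    # records_processed is a display string, e.g. "8/20" (or "8/unknown").
    print(info["status"], info["progress_percentage"], info["records_processed"])

    # Per-dataset lineage check against dataset_sync_hashes.
    if operation.was_data_synced("deadbeef", dataset_id="11111111-1111-1111-1111-111111111111"):
        print("This content hash was uploaded or downloaded in this run.")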
cognee/modules/users/methods/__init__.py CHANGED
@@ -4,4 +4,7 @@ from .delete_user import delete_user
 from .get_default_user import get_default_user
 from .get_user_by_email import get_user_by_email
 from .create_default_user import create_default_user
-from .get_authenticated_user import
+from .get_authenticated_user import (
+    get_authenticated_user,
+    REQUIRE_AUTHENTICATION,
+)
cognee/modules/users/methods/create_user.py CHANGED
@@ -1,6 +1,10 @@
+from uuid import uuid4
 from fastapi_users.exceptions import UserAlreadyExists
-
+
 from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.modules.notebooks.methods import create_notebook
+from cognee.modules.notebooks.models.Notebook import NotebookCell
+from cognee.modules.users.exceptions import TenantNotFoundError
 from cognee.modules.users.get_user_manager import get_user_manager_context
 from cognee.modules.users.get_user_db import get_user_db_context
 from cognee.modules.users.models.User import UserCreate
@@ -56,6 +60,27 @@ async def create_user(
                     if auto_login:
                         await session.refresh(user)
 
+                    await create_notebook(
+                        user_id=user.id,
+                        notebook_name="Welcome to cognee 🧠",
+                        cells=[
+                            NotebookCell(
+                                id=uuid4(),
+                                name="Welcome",
+                                content="Cognee is your toolkit for turning text into a structured knowledge graph, optionally enhanced by ontologies, and then querying it with advanced retrieval techniques. This notebook will guide you through a simple example.",
+                                type="markdown",
+                            ),
+                            NotebookCell(
+                                id=uuid4(),
+                                name="Example",
+                                content="",
+                                type="code",
+                            ),
+                        ],
+                        deletable=False,
+                        session=session,
+                    )
+
                     return user
     except UserAlreadyExists as error:
         print(f"User {email} already exists")
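The seeded cells follow the NotebookCell model added in cognee/modules/notebooks/models/Notebook.py, so building cells elsewhere uses the same shape. A short sketch (the cell names and content here are illustrative, not from the release):

from uuid import uuid4

from cognee.modules.notebooks.models.Notebook import NotebookCell

# Same shape as the cells create_user seeds: a markdown note plus an empty code cell.
cells = [
    NotebookCell(id=uuid4(), name="Notes", content="My first note", type="markdown"),
    NotebookCell(id=uuid4(), name="Scratch", content="", type="code"),
]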
cognee/modules/users/methods/get_authenticated_user.py CHANGED
@@ -1,48 +1,42 @@
+import os
+from typing import Optional
+from fastapi import Depends, HTTPException
+from ..models import User
 from ..get_fastapi_users import get_fastapi_users
+from .get_default_user import get_default_user
+from cognee.shared.logging_utils import get_logger
 
 
-
-
-get_authenticated_user = fastapi_users.current_user(active=True)
-
-# from types import SimpleNamespace
-
-# from ..get_fastapi_users import get_fastapi_users
-# from fastapi import HTTPException, Security
-# from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
-# import os
-# import jwt
-
-# from uuid import UUID
-
-# fastapi_users = get_fastapi_users()
+logger = get_logger("get_authenticated_user")
 
-#
-
+# Check environment variable to determine authentication requirement
+REQUIRE_AUTHENTICATION = (
+    os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
+    or os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true"
+)
 
+fastapi_users = get_fastapi_users()
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-#
-
-
-
-#
-
-
-
-
-# except jwt.InvalidTokenError:
-#     raise HTTPException(status_code=401, detail="Invalid token")
+_auth_dependency = fastapi_users.current_user(active=True, optional=not REQUIRE_AUTHENTICATION)
+
+
+async def get_authenticated_user(
+    user: Optional[User] = Depends(_auth_dependency),
+) -> User:
+    """
+    Get authenticated user with environment-controlled behavior:
+    - If REQUIRE_AUTHENTICATION=true: Enforces authentication (raises 401 if not authenticated)
+    - If REQUIRE_AUTHENTICATION=false: Falls back to default user if not authenticated
+
+    Always returns a User object for consistent typing.
+    """
+    if user is None:
+        # When authentication is optional and user is None, use default user
+        try:
+            user = await get_default_user()
+        except Exception as e:
+            # Convert any get_default_user failure into a proper HTTP 500 error
+            logger.error(f"Failed to create default user: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Failed to create default user: {str(e)}")
+
+    return user
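The net effect is that routes keep a plain User parameter while the enforcement policy moves into the environment. A minimal sketch of a consuming route (the /whoami endpoint is invented for illustration; cognee's real routers live under cognee/api/v1/):

from fastapi import Depends, FastAPI

from cognee.modules.users.methods import get_authenticated_user
from cognee.modules.users.models import User

app = FastAPI()


@app.get("/whoami")
async def whoami(user: User = Depends(get_authenticated_user)):
    # With REQUIRE_AUTHENTICATION=false (the default), anonymous requests
    # resolve to the default user; with it set to true they get a 401 instead.
    return {"id": str(user.id), "email": user.email}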
cognee/modules/users/methods/get_default_user.py CHANGED
@@ -29,7 +29,9 @@ async def get_default_user() -> SimpleNamespace:
 
         # We return a SimpleNamespace to have the same user type as our SaaS
         # SimpleNamespace is just a dictionary which can be accessed through attributes
-        auth_data = SimpleNamespace(
+        auth_data = SimpleNamespace(
+            id=user.id, email=user.email, tenant_id=user.tenant_id, roles=[]
+        )
         return auth_data
     except Exception as error:
         if "principals" in str(error.args):
cognee/modules/users/permissions/methods/get_specific_user_permission_datasets.py CHANGED
@@ -1,4 +1,5 @@
 from uuid import UUID
+from typing import Optional
 from cognee.modules.data.models.Dataset import Dataset
 from cognee.modules.users.permissions.methods.get_all_user_permission_datasets import (
     get_all_user_permission_datasets,
@@ -8,7 +9,7 @@ from cognee.modules.users.methods import get_user
 
 
 async def get_specific_user_permission_datasets(
-    user_id: UUID, permission_type: str, dataset_ids: list[UUID] = None
+    user_id: UUID, permission_type: str, dataset_ids: Optional[list[UUID]] = None
 ) -> list[Dataset]:
     """
     Return a list of datasets user has given permission for. If a list of datasets is provided,
|