cognee 0.2.4__py3-none-any.whl → 0.3.0.dev0__py3-none-any.whl
This diff compares publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- cognee/__init__.py +1 -0
- cognee/api/client.py +28 -3
- cognee/api/health.py +10 -13
- cognee/api/v1/add/add.py +3 -1
- cognee/api/v1/add/routers/get_add_router.py +12 -37
- cognee/api/v1/cloud/routers/__init__.py +1 -0
- cognee/api/v1/cloud/routers/get_checks_router.py +23 -0
- cognee/api/v1/cognify/code_graph_pipeline.py +9 -4
- cognee/api/v1/cognify/cognify.py +50 -3
- cognee/api/v1/cognify/routers/get_cognify_router.py +1 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +15 -4
- cognee/api/v1/memify/__init__.py +0 -0
- cognee/api/v1/memify/routers/__init__.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +100 -0
- cognee/api/v1/notebooks/routers/__init__.py +1 -0
- cognee/api/v1/notebooks/routers/get_notebooks_router.py +96 -0
- cognee/api/v1/search/routers/get_search_router.py +20 -1
- cognee/api/v1/search/search.py +11 -4
- cognee/api/v1/sync/__init__.py +17 -0
- cognee/api/v1/sync/routers/__init__.py +3 -0
- cognee/api/v1/sync/routers/get_sync_router.py +241 -0
- cognee/api/v1/sync/sync.py +877 -0
- cognee/api/v1/users/routers/get_auth_router.py +13 -1
- cognee/base_config.py +10 -1
- cognee/infrastructure/databases/graph/config.py +10 -4
- cognee/infrastructure/databases/graph/kuzu/adapter.py +135 -0
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +89 -0
- cognee/infrastructure/databases/relational/__init__.py +2 -0
- cognee/infrastructure/databases/relational/get_async_session.py +15 -0
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +6 -1
- cognee/infrastructure/databases/relational/with_async_session.py +25 -0
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +1 -1
- cognee/infrastructure/databases/vector/config.py +13 -6
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +1 -1
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +2 -6
- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +4 -1
- cognee/infrastructure/files/storage/LocalFileStorage.py +9 -0
- cognee/infrastructure/files/storage/S3FileStorage.py +5 -0
- cognee/infrastructure/files/storage/StorageManager.py +7 -1
- cognee/infrastructure/files/storage/storage.py +16 -0
- cognee/infrastructure/llm/LLMGateway.py +18 -0
- cognee/infrastructure/llm/config.py +4 -2
- cognee/infrastructure/llm/prompts/extract_query_time.txt +15 -0
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +25 -0
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +30 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +2 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_event_entities.py +44 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/__init__.py +1 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_event_graph.py +46 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -1
- cognee/infrastructure/utils/run_sync.py +8 -1
- cognee/modules/chunking/models/DocumentChunk.py +4 -3
- cognee/modules/cloud/exceptions/CloudApiKeyMissingError.py +15 -0
- cognee/modules/cloud/exceptions/CloudConnectionError.py +15 -0
- cognee/modules/cloud/exceptions/__init__.py +2 -0
- cognee/modules/cloud/operations/__init__.py +1 -0
- cognee/modules/cloud/operations/check_api_key.py +25 -0
- cognee/modules/data/deletion/prune_system.py +1 -1
- cognee/modules/data/methods/check_dataset_name.py +1 -1
- cognee/modules/data/methods/get_dataset_data.py +1 -1
- cognee/modules/data/methods/load_or_create_datasets.py +1 -1
- cognee/modules/engine/models/Event.py +16 -0
- cognee/modules/engine/models/Interval.py +8 -0
- cognee/modules/engine/models/Timestamp.py +13 -0
- cognee/modules/engine/models/__init__.py +3 -0
- cognee/modules/engine/utils/__init__.py +2 -0
- cognee/modules/engine/utils/generate_event_datapoint.py +46 -0
- cognee/modules/engine/utils/generate_timestamp_datapoint.py +51 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +2 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/resolve_edges_to_text.py +71 -0
- cognee/modules/memify/__init__.py +1 -0
- cognee/modules/memify/memify.py +118 -0
- cognee/modules/notebooks/methods/__init__.py +5 -0
- cognee/modules/notebooks/methods/create_notebook.py +26 -0
- cognee/modules/notebooks/methods/delete_notebook.py +13 -0
- cognee/modules/notebooks/methods/get_notebook.py +21 -0
- cognee/modules/notebooks/methods/get_notebooks.py +18 -0
- cognee/modules/notebooks/methods/update_notebook.py +17 -0
- cognee/modules/notebooks/models/Notebook.py +53 -0
- cognee/modules/notebooks/models/__init__.py +1 -0
- cognee/modules/notebooks/operations/__init__.py +1 -0
- cognee/modules/notebooks/operations/run_in_local_sandbox.py +55 -0
- cognee/modules/pipelines/layers/reset_dataset_pipeline_run_status.py +19 -3
- cognee/modules/pipelines/operations/pipeline.py +1 -0
- cognee/modules/pipelines/operations/run_tasks.py +17 -41
- cognee/modules/retrieval/base_graph_retriever.py +18 -0
- cognee/modules/retrieval/base_retriever.py +1 -1
- cognee/modules/retrieval/code_retriever.py +8 -0
- cognee/modules/retrieval/coding_rules_retriever.py +31 -0
- cognee/modules/retrieval/completion_retriever.py +9 -3
- cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +1 -0
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +23 -14
- cognee/modules/retrieval/graph_completion_cot_retriever.py +21 -11
- cognee/modules/retrieval/graph_completion_retriever.py +32 -65
- cognee/modules/retrieval/graph_summary_completion_retriever.py +3 -1
- cognee/modules/retrieval/insights_retriever.py +14 -3
- cognee/modules/retrieval/summaries_retriever.py +1 -1
- cognee/modules/retrieval/temporal_retriever.py +152 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +7 -32
- cognee/modules/retrieval/utils/completion.py +10 -3
- cognee/modules/search/methods/get_search_type_tools.py +168 -0
- cognee/modules/search/methods/no_access_control_search.py +47 -0
- cognee/modules/search/methods/search.py +219 -139
- cognee/modules/search/types/SearchResult.py +21 -0
- cognee/modules/search/types/SearchType.py +2 -0
- cognee/modules/search/types/__init__.py +1 -0
- cognee/modules/search/utils/__init__.py +2 -0
- cognee/modules/search/utils/prepare_search_result.py +41 -0
- cognee/modules/search/utils/transform_context_to_graph.py +38 -0
- cognee/modules/sync/__init__.py +1 -0
- cognee/modules/sync/methods/__init__.py +23 -0
- cognee/modules/sync/methods/create_sync_operation.py +53 -0
- cognee/modules/sync/methods/get_sync_operation.py +107 -0
- cognee/modules/sync/methods/update_sync_operation.py +248 -0
- cognee/modules/sync/models/SyncOperation.py +142 -0
- cognee/modules/sync/models/__init__.py +3 -0
- cognee/modules/users/__init__.py +0 -1
- cognee/modules/users/methods/__init__.py +4 -1
- cognee/modules/users/methods/create_user.py +26 -1
- cognee/modules/users/methods/get_authenticated_user.py +36 -42
- cognee/modules/users/methods/get_default_user.py +3 -1
- cognee/modules/users/permissions/methods/get_specific_user_permission_datasets.py +2 -1
- cognee/root_dir.py +19 -0
- cognee/shared/logging_utils.py +1 -1
- cognee/tasks/codingagents/__init__.py +0 -0
- cognee/tasks/codingagents/coding_rule_associations.py +127 -0
- cognee/tasks/ingestion/save_data_item_to_storage.py +23 -0
- cognee/tasks/memify/__init__.py +2 -0
- cognee/tasks/memify/extract_subgraph.py +7 -0
- cognee/tasks/memify/extract_subgraph_chunks.py +11 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +52 -27
- cognee/tasks/temporal_graph/__init__.py +1 -0
- cognee/tasks/temporal_graph/add_entities_to_event.py +85 -0
- cognee/tasks/temporal_graph/enrich_events.py +34 -0
- cognee/tasks/temporal_graph/extract_events_and_entities.py +32 -0
- cognee/tasks/temporal_graph/extract_knowledge_graph_from_events.py +41 -0
- cognee/tasks/temporal_graph/models.py +49 -0
- cognee/tests/test_kuzu.py +4 -4
- cognee/tests/test_neo4j.py +4 -4
- cognee/tests/test_permissions.py +3 -3
- cognee/tests/test_relational_db_migration.py +7 -5
- cognee/tests/test_search_db.py +18 -24
- cognee/tests/test_temporal_graph.py +167 -0
- cognee/tests/unit/api/__init__.py +1 -0
- cognee/tests/unit/api/test_conditional_authentication_endpoints.py +246 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +13 -16
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +11 -16
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +5 -4
- cognee/tests/unit/modules/retrieval/insights_retriever_test.py +4 -2
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +225 -0
- cognee/tests/unit/modules/users/__init__.py +1 -0
- cognee/tests/unit/modules/users/test_conditional_authentication.py +277 -0
- cognee/tests/unit/processing/utils/utils_test.py +20 -1
- {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/METADATA +8 -6
- {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/RECORD +162 -89
- cognee/tests/unit/modules/search/search_methods_test.py +0 -225
- {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/licenses/NOTICE.md +0 -0

--- /dev/null
+++ b/cognee/modules/search/types/SearchResult.py
@@ -0,0 +1,21 @@
+from uuid import UUID
+from pydantic import BaseModel
+from typing import Any, Dict, List, Optional
+
+
+class SearchResultDataset(BaseModel):
+    id: UUID
+    name: str
+
+
+class CombinedSearchResult(BaseModel):
+    result: Optional[Any]
+    context: Dict[str, Any]
+    graphs: Optional[Dict[str, Any]] = {}
+    datasets: Optional[List[SearchResultDataset]] = None
+
+
+class SearchResult(BaseModel):
+    search_result: Any
+    dataset_id: Optional[UUID]
+    dataset_name: Optional[str]
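
For orientation, a minimal sketch of how these models nest; all values are hypothetical, and the "*" context key mirrors the single bucket produced by prepare_search_result below:

    from uuid import uuid4
    from cognee.modules.search.types.SearchResult import (
        CombinedSearchResult,
        SearchResultDataset,
    )

    combined = CombinedSearchResult(
        result="answer text",                        # hypothetical search answer
        context={"*": "retrieved context as text"},  # single "*" bucket, as in prepare_search_result
        datasets=[SearchResultDataset(id=uuid4(), name="demo_dataset")],
    )
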
--- /dev/null
+++ b/cognee/modules/search/utils/prepare_search_result.py
@@ -0,0 +1,41 @@
+from typing import List, cast
+
+from cognee.modules.graph.utils import resolve_edges_to_text
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
+from cognee.modules.search.utils.transform_context_to_graph import transform_context_to_graph
+
+
+async def prepare_search_result(search_result):
+    result, context, datasets = search_result
+
+    graphs = None
+    result_graph = None
+    context_texts = {}
+
+    if isinstance(context, List) and len(context) > 0 and isinstance(context[0], Edge):
+        context_graph = transform_context_to_graph(context)
+
+        graphs = {
+            "*": context_graph,
+        }
+        context_texts = {
+            "*": await resolve_edges_to_text(context),
+        }
+    elif isinstance(context, str):
+        context_texts = {
+            "*": context,
+        }
+    elif isinstance(context, List) and len(context) > 0 and isinstance(context[0], str):
+        context_texts = {
+            "*": "\n".join(cast(List[str], context)),
+        }
+
+    if isinstance(result, List) and len(result) > 0 and isinstance(result[0], Edge):
+        result_graph = transform_context_to_graph(result)
+
+    return {
+        "result": result_graph or result,
+        "graphs": graphs,
+        "context": context_texts,
+        "datasets": datasets,
+    }
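
A rough sketch of the return shape for the simplest branch (a plain-string context); the input tuple is hypothetical and follows the result, context, datasets unpacking above:

    import asyncio
    from cognee.modules.search.utils.prepare_search_result import prepare_search_result

    out = asyncio.run(prepare_search_result(("some answer", "retrieved text", None)))
    assert out == {
        "result": "some answer",
        "graphs": None,
        "context": {"*": "retrieved text"},
        "datasets": None,
    }
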
--- /dev/null
+++ b/cognee/modules/search/utils/transform_context_to_graph.py
@@ -0,0 +1,38 @@
+from typing import List
+
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
+
+
+def transform_context_to_graph(context: List[Edge]):
+    nodes = {}
+    edges = {}
+
+    for triplet in context:
+        nodes[triplet.node1.id] = {
+            "id": triplet.node1.id,
+            "label": triplet.node1.attributes["name"]
+            if "name" in triplet.node1.attributes
+            else triplet.node1.id,
+            "type": triplet.node1.attributes["type"],
+            "attributes": triplet.node2.attributes,
+        }
+        nodes[triplet.node2.id] = {
+            "id": triplet.node2.id,
+            "label": triplet.node2.attributes["name"]
+            if "name" in triplet.node2.attributes
+            else triplet.node2.id,
+            "type": triplet.node2.attributes["type"],
+            "attributes": triplet.node2.attributes,
+        }
+        edges[
+            f"{triplet.node1.id}_{triplet.attributes['relationship_name']}_{triplet.node2.id}"
+        ] = {
+            "source": triplet.node1.id,
+            "target": triplet.node2.id,
+            "label": triplet.attributes["relationship_name"],
+        }
+
+    return {
+        "nodes": list(nodes.values()),
+        "edges": list(edges.values()),
+    }
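
To see the output shape without constructing real graph elements, hypothetical stand-ins exposing the same node1/node2/attributes surface as Edge can be passed through (the type hint is not enforced at runtime):

    from types import SimpleNamespace
    from cognee.modules.search.utils.transform_context_to_graph import transform_context_to_graph

    alice = SimpleNamespace(id="a1", attributes={"name": "Alice", "type": "Person"})
    bob = SimpleNamespace(id="b1", attributes={"name": "Bob", "type": "Person"})
    likes = SimpleNamespace(node1=alice, node2=bob, attributes={"relationship_name": "likes"})

    graph = transform_context_to_graph([likes])
    # graph["edges"] == [{"source": "a1", "target": "b1", "label": "likes"}]
    # graph["nodes"] holds one dict per unique node id with id, label, type, and attributes fields.
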
--- /dev/null
+++ b/cognee/modules/sync/__init__.py
@@ -0,0 +1 @@
+# Sync module for tracking sync operations
--- /dev/null
+++ b/cognee/modules/sync/methods/__init__.py
@@ -0,0 +1,23 @@
+from .create_sync_operation import create_sync_operation
+from .get_sync_operation import (
+    get_sync_operation,
+    get_user_sync_operations,
+    get_running_sync_operations_for_user,
+)
+from .update_sync_operation import (
+    update_sync_operation,
+    mark_sync_started,
+    mark_sync_completed,
+    mark_sync_failed,
+)
+
+__all__ = [
+    "create_sync_operation",
+    "get_sync_operation",
+    "get_user_sync_operations",
+    "get_running_sync_operations_for_user",
+    "update_sync_operation",
+    "mark_sync_started",
+    "mark_sync_completed",
+    "mark_sync_failed",
+]
--- /dev/null
+++ b/cognee/modules/sync/methods/create_sync_operation.py
@@ -0,0 +1,53 @@
+from uuid import UUID
+from typing import Optional, List
+from datetime import datetime, timezone
+from cognee.modules.sync.models import SyncOperation, SyncStatus
+from cognee.infrastructure.databases.relational import get_relational_engine
+
+
+async def create_sync_operation(
+    run_id: str,
+    dataset_ids: List[UUID],
+    dataset_names: List[str],
+    user_id: UUID,
+    total_records_to_sync: Optional[int] = None,
+    total_records_to_download: Optional[int] = None,
+    total_records_to_upload: Optional[int] = None,
+) -> SyncOperation:
+    """
+    Create a new sync operation record in the database.
+
+    Args:
+        run_id: Unique public identifier for this sync operation
+        dataset_ids: List of dataset UUIDs being synced
+        dataset_names: List of dataset names being synced
+        user_id: UUID of the user who initiated the sync
+        total_records_to_sync: Total number of records to sync (if known)
+        total_records_to_download: Total number of records to download (if known)
+        total_records_to_upload: Total number of records to upload (if known)
+
+    Returns:
+        SyncOperation: The created sync operation record
+    """
+    db_engine = get_relational_engine()
+
+    sync_operation = SyncOperation(
+        run_id=run_id,
+        dataset_ids=[
+            str(uuid) for uuid in dataset_ids
+        ],  # Convert UUIDs to strings for JSON storage
+        dataset_names=dataset_names,
+        user_id=user_id,
+        status=SyncStatus.STARTED,
+        total_records_to_sync=total_records_to_sync,
+        total_records_to_download=total_records_to_download,
+        total_records_to_upload=total_records_to_upload,
+        created_at=datetime.now(timezone.utc),
+    )
+
+    async with db_engine.get_async_session() as session:
+        session.add(sync_operation)
+        await session.commit()
+        await session.refresh(sync_operation)
+
+    return sync_operation
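
A hedged usage sketch; it assumes a configured relational engine with the sync operations table migrated, and every identifier below is hypothetical:

    from uuid import uuid4, UUID
    from cognee.modules.sync.methods import create_sync_operation

    async def start_sync(user_id: UUID) -> str:
        operation = await create_sync_operation(
            run_id="sync-2f9c1e",          # hypothetical public run id
            dataset_ids=[uuid4()],
            dataset_names=["demo_dataset"],
            user_id=user_id,
            total_records_to_sync=100,
        )
        return operation.run_id
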
--- /dev/null
+++ b/cognee/modules/sync/methods/get_sync_operation.py
@@ -0,0 +1,107 @@
+from uuid import UUID
+from typing import List, Optional
+from sqlalchemy import select, desc, and_
+from cognee.modules.sync.models import SyncOperation, SyncStatus
+from cognee.infrastructure.databases.relational import get_relational_engine
+
+
+async def get_sync_operation(run_id: str) -> Optional[SyncOperation]:
+    """
+    Get a sync operation by its run_id.
+
+    Args:
+        run_id: The public run_id of the sync operation
+
+    Returns:
+        SyncOperation: The sync operation record, or None if not found
+    """
+    db_engine = get_relational_engine()
+
+    async with db_engine.get_async_session() as session:
+        query = select(SyncOperation).where(SyncOperation.run_id == run_id)
+        result = await session.execute(query)
+        return result.scalars().first()
+
+
+async def get_user_sync_operations(
+    user_id: UUID, limit: int = 50, offset: int = 0
+) -> List[SyncOperation]:
+    """
+    Get sync operations for a specific user, ordered by most recent first.
+
+    Args:
+        user_id: UUID of the user
+        limit: Maximum number of records to return
+        offset: Number of records to skip
+
+    Returns:
+        List[SyncOperation]: List of sync operations for the user
+    """
+    db_engine = get_relational_engine()
+
+    async with db_engine.get_async_session() as session:
+        query = (
+            select(SyncOperation)
+            .where(SyncOperation.user_id == user_id)
+            .order_by(desc(SyncOperation.created_at))
+            .limit(limit)
+            .offset(offset)
+        )
+        result = await session.execute(query)
+        return list(result.scalars().all())
+
+
+async def get_sync_operations_by_dataset(
+    dataset_id: UUID, limit: int = 50, offset: int = 0
+) -> List[SyncOperation]:
+    """
+    Get sync operations for a specific dataset.
+
+    Args:
+        dataset_id: UUID of the dataset
+        limit: Maximum number of records to return
+        offset: Number of records to skip
+
+    Returns:
+        List[SyncOperation]: List of sync operations for the dataset
+    """
+    db_engine = get_relational_engine()
+
+    async with db_engine.get_async_session() as session:
+        query = (
+            select(SyncOperation)
+            .where(SyncOperation.dataset_id == dataset_id)
+            .order_by(desc(SyncOperation.created_at))
+            .limit(limit)
+            .offset(offset)
+        )
+        result = await session.execute(query)
+        return list(result.scalars().all())
+
+
+async def get_running_sync_operations_for_user(user_id: UUID) -> List[SyncOperation]:
+    """
+    Get all currently running sync operations for a specific user.
+    Checks for operations with STARTED or IN_PROGRESS status.
+
+    Args:
+        user_id: UUID of the user
+
+    Returns:
+        List[SyncOperation]: List of running sync operations for the user
+    """
+    db_engine = get_relational_engine()
+
+    async with db_engine.get_async_session() as session:
+        query = (
+            select(SyncOperation)
+            .where(
+                and_(
+                    SyncOperation.user_id == user_id,
+                    SyncOperation.status.in_([SyncStatus.STARTED, SyncStatus.IN_PROGRESS]),
+                )
+            )
+            .order_by(desc(SyncOperation.created_at))
+        )
+        result = await session.execute(query)
+        return list(result.scalars().all())
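
These getters compose naturally into status polling; a sketch under the same engine assumptions as above:

    from cognee.modules.sync.methods import get_sync_operation
    from cognee.modules.sync.models import SyncStatus

    async def is_sync_finished(run_id: str) -> bool:
        operation = await get_sync_operation(run_id)
        if operation is None:
            return False
        return operation.status not in (SyncStatus.STARTED, SyncStatus.IN_PROGRESS)
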
--- /dev/null
+++ b/cognee/modules/sync/methods/update_sync_operation.py
@@ -0,0 +1,248 @@
+import asyncio
+from typing import Optional, List
+from datetime import datetime, timezone
+from sqlalchemy import select
+from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError, TimeoutError
+from cognee.modules.sync.models import SyncOperation, SyncStatus
+from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
+
+logger = get_logger("sync.db_operations")
+
+
+async def _retry_db_operation(operation_func, run_id: str, max_retries: int = 3):
+    """
+    Retry database operations with exponential backoff for transient failures.
+
+    Args:
+        operation_func: Async function to retry
+        run_id: Run ID for logging context
+        max_retries: Maximum number of retry attempts
+
+    Returns:
+        Result of the operation function
+
+    Raises:
+        Exception: Re-raises the last exception if all retries fail
+    """
+    attempt = 0
+    last_exception = None
+
+    while attempt < max_retries:
+        try:
+            return await operation_func()
+        except (DisconnectionError, OperationalError, TimeoutError) as e:
+            attempt += 1
+            last_exception = e
+
+            if attempt >= max_retries:
+                logger.error(
+                    f"Database operation failed after {max_retries} attempts for run_id {run_id}: {str(e)}"
+                )
+                break
+
+            backoff_time = calculate_backoff(attempt - 1)  # calculate_backoff is 0-indexed
+            logger.warning(
+                f"Database operation failed for run_id {run_id}, retrying in {backoff_time:.2f}s (attempt {attempt}/{max_retries}): {str(e)}"
+            )
+            await asyncio.sleep(backoff_time)
+
+        except Exception as e:
+            # Non-transient errors should not be retried
+            logger.error(f"Non-retryable database error for run_id {run_id}: {str(e)}")
+            raise
+
+    # If we get here, all retries failed
+    raise last_exception
+
+
+async def update_sync_operation(
+    run_id: str,
+    status: Optional[SyncStatus] = None,
+    progress_percentage: Optional[int] = None,
+    records_downloaded: Optional[int] = None,
+    total_records_to_sync: Optional[int] = None,
+    total_records_to_download: Optional[int] = None,
+    total_records_to_upload: Optional[int] = None,
+    records_uploaded: Optional[int] = None,
+    bytes_downloaded: Optional[int] = None,
+    bytes_uploaded: Optional[int] = None,
+    dataset_sync_hashes: Optional[dict] = None,
+    error_message: Optional[str] = None,
+    retry_count: Optional[int] = None,
+    started_at: Optional[datetime] = None,
+    completed_at: Optional[datetime] = None,
+) -> Optional[SyncOperation]:
+    """
+    Update a sync operation record with new status/progress information.
+
+    Args:
+        run_id: The public run_id of the sync operation to update
+        status: New status for the operation
+        progress_percentage: Progress percentage (0-100)
+        records_downloaded: Number of records downloaded so far
+        total_records_to_sync: Total number of records that need to be synced
+        total_records_to_download: Total number of records to download from cloud
+        total_records_to_upload: Total number of records to upload to cloud
+        records_uploaded: Number of records uploaded so far
+        bytes_downloaded: Total bytes downloaded from cloud
+        bytes_uploaded: Total bytes uploaded to cloud
+        dataset_sync_hashes: Dict mapping dataset_id -> {uploaded: [hashes], downloaded: [hashes]}
+        error_message: Error message if operation failed
+        retry_count: Number of retry attempts
+        started_at: When the actual processing started
+        completed_at: When the operation completed (success or failure)
+
+    Returns:
+        SyncOperation: The updated sync operation record, or None if not found
+    """
+
+    async def _perform_update():
+        db_engine = get_relational_engine()
+
+        async with db_engine.get_async_session() as session:
+            try:
+                # Find the sync operation
+                query = select(SyncOperation).where(SyncOperation.run_id == run_id)
+                result = await session.execute(query)
+                sync_operation = result.scalars().first()
+
+                if not sync_operation:
+                    logger.warning(f"Sync operation not found for run_id: {run_id}")
+                    return None
+
+                # Log what we're updating for debugging
+                updates = []
+                if status is not None:
+                    updates.append(f"status={status.value}")
+                if progress_percentage is not None:
+                    updates.append(f"progress={progress_percentage}%")
+                if records_downloaded is not None:
+                    updates.append(f"downloaded={records_downloaded}")
+                if records_uploaded is not None:
+                    updates.append(f"uploaded={records_uploaded}")
+                if total_records_to_sync is not None:
+                    updates.append(f"total_sync={total_records_to_sync}")
+                if total_records_to_download is not None:
+                    updates.append(f"total_download={total_records_to_download}")
+                if total_records_to_upload is not None:
+                    updates.append(f"total_upload={total_records_to_upload}")
+
+                if updates:
+                    logger.debug(f"Updating sync operation {run_id}: {', '.join(updates)}")
+
+                # Update fields that were provided
+                if status is not None:
+                    sync_operation.status = status
+
+                if progress_percentage is not None:
+                    sync_operation.progress_percentage = max(0, min(100, progress_percentage))
+
+                if records_downloaded is not None:
+                    sync_operation.records_downloaded = records_downloaded
+
+                if records_uploaded is not None:
+                    sync_operation.records_uploaded = records_uploaded
+
+                if total_records_to_sync is not None:
+                    sync_operation.total_records_to_sync = total_records_to_sync
+
+                if total_records_to_download is not None:
+                    sync_operation.total_records_to_download = total_records_to_download
+
+                if total_records_to_upload is not None:
+                    sync_operation.total_records_to_upload = total_records_to_upload
+
+                if bytes_downloaded is not None:
+                    sync_operation.bytes_downloaded = bytes_downloaded
+
+                if bytes_uploaded is not None:
+                    sync_operation.bytes_uploaded = bytes_uploaded
+
+                if dataset_sync_hashes is not None:
+                    sync_operation.dataset_sync_hashes = dataset_sync_hashes
+
+                if error_message is not None:
+                    sync_operation.error_message = error_message
+
+                if retry_count is not None:
+                    sync_operation.retry_count = retry_count
+
+                if started_at is not None:
+                    sync_operation.started_at = started_at
+
+                if completed_at is not None:
+                    sync_operation.completed_at = completed_at
+
+                # Auto-set completion timestamp for terminal statuses
+                if (
+                    status in [SyncStatus.COMPLETED, SyncStatus.FAILED, SyncStatus.CANCELLED]
+                    and completed_at is None
+                ):
+                    sync_operation.completed_at = datetime.now(timezone.utc)
+
+                # Auto-set started timestamp when moving to IN_PROGRESS
+                if status == SyncStatus.IN_PROGRESS and sync_operation.started_at is None:
+                    sync_operation.started_at = datetime.now(timezone.utc)
+
+                await session.commit()
+                await session.refresh(sync_operation)
+
+                logger.debug(f"Successfully updated sync operation {run_id}")
+                return sync_operation
+
+            except SQLAlchemyError as e:
+                logger.error(
+                    f"Database error updating sync operation {run_id}: {str(e)}", exc_info=True
+                )
+                await session.rollback()
+                raise
+            except Exception as e:
+                logger.error(
+                    f"Unexpected error updating sync operation {run_id}: {str(e)}", exc_info=True
+                )
+                await session.rollback()
+                raise
+
+    # Use retry logic for the database operation
+    return await _retry_db_operation(_perform_update, run_id)
+
+
+async def mark_sync_started(run_id: str) -> Optional[SyncOperation]:
+    """Convenience method to mark a sync operation as started."""
+    return await update_sync_operation(
+        run_id=run_id, status=SyncStatus.IN_PROGRESS, started_at=datetime.now(timezone.utc)
+    )
+
+
+async def mark_sync_completed(
+    run_id: str,
+    records_downloaded: int = 0,
+    records_uploaded: int = 0,
+    bytes_downloaded: int = 0,
+    bytes_uploaded: int = 0,
+    dataset_sync_hashes: Optional[dict] = None,
+) -> Optional[SyncOperation]:
+    """Convenience method to mark a sync operation as completed successfully."""
+    return await update_sync_operation(
+        run_id=run_id,
+        status=SyncStatus.COMPLETED,
+        progress_percentage=100,
+        records_downloaded=records_downloaded,
+        records_uploaded=records_uploaded,
+        bytes_downloaded=bytes_downloaded,
+        bytes_uploaded=bytes_uploaded,
+        dataset_sync_hashes=dataset_sync_hashes,
+        completed_at=datetime.now(timezone.utc),
+    )
+
+
+async def mark_sync_failed(run_id: str, error_message: str) -> Optional[SyncOperation]:
+    """Convenience method to mark a sync operation as failed."""
+    return await update_sync_operation(
+        run_id=run_id,
+        status=SyncStatus.FAILED,
+        error_message=error_message,
+        completed_at=datetime.now(timezone.utc),
+    )
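
Taken together, the mark_* helpers bracket a sync job's lifecycle; a sketch with the actual transfer work elided (the real orchestration lives in cognee/api/v1/sync/sync.py in this release):

    from cognee.modules.sync.methods import (
        mark_sync_started,
        mark_sync_completed,
        mark_sync_failed,
    )

    async def run_sync_job(run_id: str):
        await mark_sync_started(run_id)
        try:
            # ... perform the upload/download work here ...
            await mark_sync_completed(run_id, records_uploaded=10, records_downloaded=5)
        except Exception as error:
            await mark_sync_failed(run_id, error_message=str(error))
            raise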