cognee 0.2.4__py3-none-any.whl → 0.3.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. cognee/__init__.py +1 -0
  2. cognee/api/client.py +28 -3
  3. cognee/api/health.py +10 -13
  4. cognee/api/v1/add/add.py +3 -1
  5. cognee/api/v1/add/routers/get_add_router.py +12 -37
  6. cognee/api/v1/cloud/routers/__init__.py +1 -0
  7. cognee/api/v1/cloud/routers/get_checks_router.py +23 -0
  8. cognee/api/v1/cognify/code_graph_pipeline.py +9 -4
  9. cognee/api/v1/cognify/cognify.py +50 -3
  10. cognee/api/v1/cognify/routers/get_cognify_router.py +1 -1
  11. cognee/api/v1/datasets/routers/get_datasets_router.py +15 -4
  12. cognee/api/v1/memify/__init__.py +0 -0
  13. cognee/api/v1/memify/routers/__init__.py +1 -0
  14. cognee/api/v1/memify/routers/get_memify_router.py +100 -0
  15. cognee/api/v1/notebooks/routers/__init__.py +1 -0
  16. cognee/api/v1/notebooks/routers/get_notebooks_router.py +96 -0
  17. cognee/api/v1/search/routers/get_search_router.py +20 -1
  18. cognee/api/v1/search/search.py +11 -4
  19. cognee/api/v1/sync/__init__.py +17 -0
  20. cognee/api/v1/sync/routers/__init__.py +3 -0
  21. cognee/api/v1/sync/routers/get_sync_router.py +241 -0
  22. cognee/api/v1/sync/sync.py +877 -0
  23. cognee/api/v1/users/routers/get_auth_router.py +13 -1
  24. cognee/base_config.py +10 -1
  25. cognee/infrastructure/databases/graph/config.py +10 -4
  26. cognee/infrastructure/databases/graph/kuzu/adapter.py +135 -0
  27. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +89 -0
  28. cognee/infrastructure/databases/relational/__init__.py +2 -0
  29. cognee/infrastructure/databases/relational/get_async_session.py +15 -0
  30. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +6 -1
  31. cognee/infrastructure/databases/relational/with_async_session.py +25 -0
  32. cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +1 -1
  33. cognee/infrastructure/databases/vector/config.py +13 -6
  34. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +1 -1
  35. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +2 -6
  36. cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +4 -1
  37. cognee/infrastructure/files/storage/LocalFileStorage.py +9 -0
  38. cognee/infrastructure/files/storage/S3FileStorage.py +5 -0
  39. cognee/infrastructure/files/storage/StorageManager.py +7 -1
  40. cognee/infrastructure/files/storage/storage.py +16 -0
  41. cognee/infrastructure/llm/LLMGateway.py +18 -0
  42. cognee/infrastructure/llm/config.py +4 -2
  43. cognee/infrastructure/llm/prompts/extract_query_time.txt +15 -0
  44. cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +25 -0
  45. cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +30 -0
  46. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +2 -0
  47. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_event_entities.py +44 -0
  48. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/__init__.py +1 -0
  49. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_event_graph.py +46 -0
  50. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -1
  51. cognee/infrastructure/utils/run_sync.py +8 -1
  52. cognee/modules/chunking/models/DocumentChunk.py +4 -3
  53. cognee/modules/cloud/exceptions/CloudApiKeyMissingError.py +15 -0
  54. cognee/modules/cloud/exceptions/CloudConnectionError.py +15 -0
  55. cognee/modules/cloud/exceptions/__init__.py +2 -0
  56. cognee/modules/cloud/operations/__init__.py +1 -0
  57. cognee/modules/cloud/operations/check_api_key.py +25 -0
  58. cognee/modules/data/deletion/prune_system.py +1 -1
  59. cognee/modules/data/methods/check_dataset_name.py +1 -1
  60. cognee/modules/data/methods/get_dataset_data.py +1 -1
  61. cognee/modules/data/methods/load_or_create_datasets.py +1 -1
  62. cognee/modules/engine/models/Event.py +16 -0
  63. cognee/modules/engine/models/Interval.py +8 -0
  64. cognee/modules/engine/models/Timestamp.py +13 -0
  65. cognee/modules/engine/models/__init__.py +3 -0
  66. cognee/modules/engine/utils/__init__.py +2 -0
  67. cognee/modules/engine/utils/generate_event_datapoint.py +46 -0
  68. cognee/modules/engine/utils/generate_timestamp_datapoint.py +51 -0
  69. cognee/modules/graph/cognee_graph/CogneeGraph.py +2 -2
  70. cognee/modules/graph/utils/__init__.py +1 -0
  71. cognee/modules/graph/utils/resolve_edges_to_text.py +71 -0
  72. cognee/modules/memify/__init__.py +1 -0
  73. cognee/modules/memify/memify.py +118 -0
  74. cognee/modules/notebooks/methods/__init__.py +5 -0
  75. cognee/modules/notebooks/methods/create_notebook.py +26 -0
  76. cognee/modules/notebooks/methods/delete_notebook.py +13 -0
  77. cognee/modules/notebooks/methods/get_notebook.py +21 -0
  78. cognee/modules/notebooks/methods/get_notebooks.py +18 -0
  79. cognee/modules/notebooks/methods/update_notebook.py +17 -0
  80. cognee/modules/notebooks/models/Notebook.py +53 -0
  81. cognee/modules/notebooks/models/__init__.py +1 -0
  82. cognee/modules/notebooks/operations/__init__.py +1 -0
  83. cognee/modules/notebooks/operations/run_in_local_sandbox.py +55 -0
  84. cognee/modules/pipelines/layers/reset_dataset_pipeline_run_status.py +19 -3
  85. cognee/modules/pipelines/operations/pipeline.py +1 -0
  86. cognee/modules/pipelines/operations/run_tasks.py +17 -41
  87. cognee/modules/retrieval/base_graph_retriever.py +18 -0
  88. cognee/modules/retrieval/base_retriever.py +1 -1
  89. cognee/modules/retrieval/code_retriever.py +8 -0
  90. cognee/modules/retrieval/coding_rules_retriever.py +31 -0
  91. cognee/modules/retrieval/completion_retriever.py +9 -3
  92. cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +1 -0
  93. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +23 -14
  94. cognee/modules/retrieval/graph_completion_cot_retriever.py +21 -11
  95. cognee/modules/retrieval/graph_completion_retriever.py +32 -65
  96. cognee/modules/retrieval/graph_summary_completion_retriever.py +3 -1
  97. cognee/modules/retrieval/insights_retriever.py +14 -3
  98. cognee/modules/retrieval/summaries_retriever.py +1 -1
  99. cognee/modules/retrieval/temporal_retriever.py +152 -0
  100. cognee/modules/retrieval/utils/brute_force_triplet_search.py +7 -32
  101. cognee/modules/retrieval/utils/completion.py +10 -3
  102. cognee/modules/search/methods/get_search_type_tools.py +168 -0
  103. cognee/modules/search/methods/no_access_control_search.py +47 -0
  104. cognee/modules/search/methods/search.py +219 -139
  105. cognee/modules/search/types/SearchResult.py +21 -0
  106. cognee/modules/search/types/SearchType.py +2 -0
  107. cognee/modules/search/types/__init__.py +1 -0
  108. cognee/modules/search/utils/__init__.py +2 -0
  109. cognee/modules/search/utils/prepare_search_result.py +41 -0
  110. cognee/modules/search/utils/transform_context_to_graph.py +38 -0
  111. cognee/modules/sync/__init__.py +1 -0
  112. cognee/modules/sync/methods/__init__.py +23 -0
  113. cognee/modules/sync/methods/create_sync_operation.py +53 -0
  114. cognee/modules/sync/methods/get_sync_operation.py +107 -0
  115. cognee/modules/sync/methods/update_sync_operation.py +248 -0
  116. cognee/modules/sync/models/SyncOperation.py +142 -0
  117. cognee/modules/sync/models/__init__.py +3 -0
  118. cognee/modules/users/__init__.py +0 -1
  119. cognee/modules/users/methods/__init__.py +4 -1
  120. cognee/modules/users/methods/create_user.py +26 -1
  121. cognee/modules/users/methods/get_authenticated_user.py +36 -42
  122. cognee/modules/users/methods/get_default_user.py +3 -1
  123. cognee/modules/users/permissions/methods/get_specific_user_permission_datasets.py +2 -1
  124. cognee/root_dir.py +19 -0
  125. cognee/shared/logging_utils.py +1 -1
  126. cognee/tasks/codingagents/__init__.py +0 -0
  127. cognee/tasks/codingagents/coding_rule_associations.py +127 -0
  128. cognee/tasks/ingestion/save_data_item_to_storage.py +23 -0
  129. cognee/tasks/memify/__init__.py +2 -0
  130. cognee/tasks/memify/extract_subgraph.py +7 -0
  131. cognee/tasks/memify/extract_subgraph_chunks.py +11 -0
  132. cognee/tasks/repo_processor/get_repo_file_dependencies.py +52 -27
  133. cognee/tasks/temporal_graph/__init__.py +1 -0
  134. cognee/tasks/temporal_graph/add_entities_to_event.py +85 -0
  135. cognee/tasks/temporal_graph/enrich_events.py +34 -0
  136. cognee/tasks/temporal_graph/extract_events_and_entities.py +32 -0
  137. cognee/tasks/temporal_graph/extract_knowledge_graph_from_events.py +41 -0
  138. cognee/tasks/temporal_graph/models.py +49 -0
  139. cognee/tests/test_kuzu.py +4 -4
  140. cognee/tests/test_neo4j.py +4 -4
  141. cognee/tests/test_permissions.py +3 -3
  142. cognee/tests/test_relational_db_migration.py +7 -5
  143. cognee/tests/test_search_db.py +18 -24
  144. cognee/tests/test_temporal_graph.py +167 -0
  145. cognee/tests/unit/api/__init__.py +1 -0
  146. cognee/tests/unit/api/test_conditional_authentication_endpoints.py +246 -0
  147. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +18 -2
  148. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +13 -16
  149. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +11 -16
  150. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +5 -4
  151. cognee/tests/unit/modules/retrieval/insights_retriever_test.py +4 -2
  152. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +18 -2
  153. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +225 -0
  154. cognee/tests/unit/modules/users/__init__.py +1 -0
  155. cognee/tests/unit/modules/users/test_conditional_authentication.py +277 -0
  156. cognee/tests/unit/processing/utils/utils_test.py +20 -1
  157. {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/METADATA +8 -6
  158. {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/RECORD +162 -89
  159. cognee/tests/unit/modules/search/search_methods_test.py +0 -225
  160. {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/WHEEL +0 -0
  161. {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/entry_points.txt +0 -0
  162. {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/licenses/LICENSE +0 -0
  163. {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -0,0 +1,21 @@
1
+ from uuid import UUID
2
+ from pydantic import BaseModel
3
+ from typing import Any, Dict, List, Optional
4
+
5
+
6
class SearchResultDataset(BaseModel):
    """Lightweight reference to a dataset that contributed to a search result."""

    id: UUID
    name: str
9
+
10
+
11
class CombinedSearchResult(BaseModel):
    """Aggregated search response: the answer plus its context and provenance."""

    # Final search output; shape depends on the search type, hence Any.
    result: Optional[Any]
    # Context payloads keyed by scope label.
    context: Dict[str, Any]
    # Optional graph rendering(s) of the context, keyed like `context`.
    # (A mutable default is safe on a pydantic model: defaults are copied per instance.)
    graphs: Optional[Dict[str, Any]] = {}
    # Datasets the result was drawn from, when attribution is available.
    datasets: Optional[List[SearchResultDataset]] = None
16
+
17
+
18
class SearchResult(BaseModel):
    """A single search result together with the dataset it was found in."""

    search_result: Any
    # Dataset attribution; both may be None when the result is not tied to a dataset.
    dataset_id: Optional[UUID]
    dataset_name: Optional[str]
@@ -15,3 +15,5 @@ class SearchType(Enum):
15
15
  GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
16
16
  FEELING_LUCKY = "FEELING_LUCKY"
17
17
  FEEDBACK = "FEEDBACK"
18
+ TEMPORAL = "TEMPORAL"
19
+ CODING_RULES = "CODING_RULES"
@@ -1 +1,2 @@
1
1
  from .SearchType import SearchType
2
+ from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult
@@ -0,0 +1,2 @@
1
+ from .prepare_search_result import prepare_search_result
2
+ from .transform_context_to_graph import transform_context_to_graph
@@ -0,0 +1,41 @@
1
+ from typing import List, cast
2
+
3
+ from cognee.modules.graph.utils import resolve_edges_to_text
4
+ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
5
+ from cognee.modules.search.utils.transform_context_to_graph import transform_context_to_graph
6
+
7
+
8
async def prepare_search_result(search_result):
    """Normalize a raw (result, context, datasets) search triple into a response dict.

    The context may arrive as a list of Edge triplets, a plain string, or a list
    of strings; each form is reduced to a text mapping keyed by "*", and Edge
    lists are additionally rendered as a graph. An Edge-list result is likewise
    converted to a graph payload.
    """
    result, context, datasets = search_result

    def is_nonempty_list_of(value, element_type):
        # Mirrors the original guard: non-empty list whose first element types the rest.
        return isinstance(value, List) and len(value) > 0 and isinstance(value[0], element_type)

    graphs = None
    context_texts = {}

    if is_nonempty_list_of(context, Edge):
        graphs = {"*": transform_context_to_graph(context)}
        context_texts = {"*": await resolve_edges_to_text(context)}
    elif isinstance(context, str):
        context_texts = {"*": context}
    elif is_nonempty_list_of(context, str):
        context_texts = {"*": "\n".join(cast(List[str], context))}

    result_graph = transform_context_to_graph(result) if is_nonempty_list_of(result, Edge) else None

    return {
        "result": result_graph or result,
        "graphs": graphs,
        "context": context_texts,
        "datasets": datasets,
    }
@@ -0,0 +1,38 @@
1
+ from typing import List
2
+
3
+ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
4
+
5
+
6
def transform_context_to_graph(context: "List[Edge]"):
    """Convert a list of Edge triplets into a serializable nodes/edges graph dict.

    Nodes are deduplicated by id (last write wins); edges are deduplicated by the
    "source_relationship_target" key. Each node's label falls back to its id when
    the node has no "name" attribute.

    Args:
        context: Edge triplets exposing node1, node2 (each with .id and
            .attributes) and edge .attributes containing "relationship_name".

    Returns:
        dict: {"nodes": [...], "edges": [...]} suitable for JSON serialization.
    """
    nodes = {}
    edges = {}

    for triplet in context:
        # BUG FIX: the node1 entry previously stored triplet.node2.attributes,
        # so every source node carried its target's attributes.
        for node in (triplet.node1, triplet.node2):
            nodes[node.id] = {
                "id": node.id,
                "label": node.attributes["name"] if "name" in node.attributes else node.id,
                "type": node.attributes["type"],
                "attributes": node.attributes,
            }

        relationship = triplet.attributes["relationship_name"]
        edges[f"{triplet.node1.id}_{relationship}_{triplet.node2.id}"] = {
            "source": triplet.node1.id,
            "target": triplet.node2.id,
            "label": relationship,
        }

    return {
        "nodes": list(nodes.values()),
        "edges": list(edges.values()),
    }
@@ -0,0 +1 @@
1
+ # Sync module for tracking sync operations
@@ -0,0 +1,23 @@
1
+ from .create_sync_operation import create_sync_operation
2
+ from .get_sync_operation import (
3
+ get_sync_operation,
4
+ get_user_sync_operations,
5
+ get_running_sync_operations_for_user,
6
+ )
7
+ from .update_sync_operation import (
8
+ update_sync_operation,
9
+ mark_sync_started,
10
+ mark_sync_completed,
11
+ mark_sync_failed,
12
+ )
13
+
14
+ __all__ = [
15
+ "create_sync_operation",
16
+ "get_sync_operation",
17
+ "get_user_sync_operations",
18
+ "get_running_sync_operations_for_user",
19
+ "update_sync_operation",
20
+ "mark_sync_started",
21
+ "mark_sync_completed",
22
+ "mark_sync_failed",
23
+ ]
@@ -0,0 +1,53 @@
1
+ from uuid import UUID
2
+ from typing import Optional, List
3
+ from datetime import datetime, timezone
4
+ from cognee.modules.sync.models import SyncOperation, SyncStatus
5
+ from cognee.infrastructure.databases.relational import get_relational_engine
6
+
7
+
8
async def create_sync_operation(
    run_id: str,
    dataset_ids: List[UUID],
    dataset_names: List[str],
    user_id: UUID,
    total_records_to_sync: Optional[int] = None,
    total_records_to_download: Optional[int] = None,
    total_records_to_upload: Optional[int] = None,
) -> SyncOperation:
    """Persist a new sync operation row in STARTED state and return it.

    Args:
        run_id: Unique public identifier for this sync operation.
        dataset_ids: UUIDs of the datasets being synced.
        dataset_names: Names of the datasets being synced.
        user_id: UUID of the user who initiated the sync.
        total_records_to_sync: Total records to sync, when known up front.
        total_records_to_download: Total records to download, when known.
        total_records_to_upload: Total records to upload, when known.

    Returns:
        SyncOperation: The freshly created, committed record.
    """
    # dataset_ids is persisted as JSON, so UUIDs must be stringified first.
    serialized_dataset_ids = [str(dataset_uuid) for dataset_uuid in dataset_ids]

    operation = SyncOperation(
        run_id=run_id,
        dataset_ids=serialized_dataset_ids,
        dataset_names=dataset_names,
        user_id=user_id,
        status=SyncStatus.STARTED,
        total_records_to_sync=total_records_to_sync,
        total_records_to_download=total_records_to_download,
        total_records_to_upload=total_records_to_upload,
        created_at=datetime.now(timezone.utc),
    )

    engine = get_relational_engine()
    async with engine.get_async_session() as session:
        session.add(operation)
        await session.commit()
        # Refresh so DB-generated fields are populated before returning.
        await session.refresh(operation)

    return operation
@@ -0,0 +1,107 @@
1
+ from uuid import UUID
2
+ from typing import List, Optional
3
+ from sqlalchemy import select, desc, and_
4
+ from cognee.modules.sync.models import SyncOperation, SyncStatus
5
+ from cognee.infrastructure.databases.relational import get_relational_engine
6
+
7
+
8
async def get_sync_operation(run_id: str) -> Optional[SyncOperation]:
    """Fetch a sync operation by its public run_id, or None when no row matches."""
    engine = get_relational_engine()

    async with engine.get_async_session() as session:
        rows = await session.execute(
            select(SyncOperation).where(SyncOperation.run_id == run_id)
        )
        return rows.scalars().first()
24
+
25
+
26
async def get_user_sync_operations(
    user_id: UUID, limit: int = 50, offset: int = 0
) -> List[SyncOperation]:
    """Return one page of a user's sync operations, newest first.

    Args:
        user_id: Owner of the sync operations.
        limit: Maximum number of rows in the page.
        offset: Number of leading rows to skip.

    Returns:
        List[SyncOperation]: The requested page, possibly empty.
    """
    stmt = (
        select(SyncOperation)
        .where(SyncOperation.user_id == user_id)
        .order_by(desc(SyncOperation.created_at))
        .limit(limit)
        .offset(offset)
    )

    engine = get_relational_engine()
    async with engine.get_async_session() as session:
        found = await session.execute(stmt)
        return list(found.scalars().all())
52
+
53
+
54
async def get_sync_operations_by_dataset(
    dataset_id: UUID, limit: int = 50, offset: int = 0
) -> List[SyncOperation]:
    """
    Get sync operations for a specific dataset.

    Args:
        dataset_id: UUID of the dataset
        limit: Maximum number of records to return
        offset: Number of records to skip

    Returns:
        List[SyncOperation]: List of sync operations for the dataset
    """
    # NOTE(review): this helper is not re-exported from methods/__init__.py,
    # so it appears to be currently unused.
    db_engine = get_relational_engine()

    async with db_engine.get_async_session() as session:
        query = (
            select(SyncOperation)
            # NOTE(review): create_sync_operation populates a `dataset_ids` JSON
            # list on SyncOperation; a scalar `dataset_id` column is not evident
            # from this file. Confirm the attribute exists on the model —
            # otherwise this query raises AttributeError at call time.
            .where(SyncOperation.dataset_id == dataset_id)
            .order_by(desc(SyncOperation.created_at))
            .limit(limit)
            .offset(offset)
        )
        result = await session.execute(query)
        return list(result.scalars().all())
80
+
81
+
82
async def get_running_sync_operations_for_user(user_id: UUID) -> List[SyncOperation]:
    """Return the user's in-flight sync operations, newest first.

    An operation counts as running while its status is STARTED or IN_PROGRESS.

    Args:
        user_id: Owner of the sync operations.

    Returns:
        List[SyncOperation]: Running operations, possibly empty.
    """
    active_statuses = [SyncStatus.STARTED, SyncStatus.IN_PROGRESS]

    stmt = (
        select(SyncOperation)
        .where(
            and_(
                SyncOperation.user_id == user_id,
                SyncOperation.status.in_(active_statuses),
            )
        )
        .order_by(desc(SyncOperation.created_at))
    )

    engine = get_relational_engine()
    async with engine.get_async_session() as session:
        found = await session.execute(stmt)
        return list(found.scalars().all())
@@ -0,0 +1,248 @@
1
+ import asyncio
2
+ from typing import Optional, List
3
+ from datetime import datetime, timezone
4
+ from sqlalchemy import select
5
+ from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError, TimeoutError
6
+ from cognee.modules.sync.models import SyncOperation, SyncStatus
7
+ from cognee.infrastructure.databases.relational import get_relational_engine
8
+ from cognee.shared.logging_utils import get_logger
9
+ from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
10
+
11
+ logger = get_logger("sync.db_operations")
12
+
13
+
14
async def _retry_db_operation(operation_func, run_id: str, max_retries: int = 3):
    """
    Retry a database operation with exponential backoff for transient failures.

    Only connectivity-style errors (DisconnectionError, OperationalError,
    TimeoutError) are retried; any other exception is logged and re-raised
    immediately.

    Args:
        operation_func: Zero-argument async callable to run.
        run_id: Run ID, used only for log context.
        max_retries: Maximum number of attempts; must be at least 1.

    Returns:
        Result of the operation function.

    Raises:
        ValueError: If max_retries < 1.
        Exception: The last transient error when all attempts fail, or the
            first non-transient error encountered.
    """
    # Guard against max_retries < 1: previously the loop body never ran and the
    # trailing `raise last_exception` executed `raise None`, surfacing as a
    # confusing TypeError. Fail fast with a clear message instead.
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    last_exception = None

    for attempt in range(1, max_retries + 1):
        try:
            return await operation_func()
        except (DisconnectionError, OperationalError, TimeoutError) as e:
            last_exception = e

            if attempt >= max_retries:
                logger.error(
                    f"Database operation failed after {max_retries} attempts for run_id {run_id}: {str(e)}"
                )
                break

            backoff_time = calculate_backoff(attempt - 1)  # calculate_backoff is 0-indexed
            logger.warning(
                f"Database operation failed for run_id {run_id}, retrying in {backoff_time:.2f}s (attempt {attempt}/{max_retries}): {str(e)}"
            )
            await asyncio.sleep(backoff_time)

        except Exception as e:
            # Non-transient errors should not be retried
            logger.error(f"Non-retryable database error for run_id {run_id}: {str(e)}")
            raise

    # All attempts exhausted on a transient error.
    raise last_exception
58
+
59
+
60
async def update_sync_operation(
    run_id: str,
    status: Optional[SyncStatus] = None,
    progress_percentage: Optional[int] = None,
    records_downloaded: Optional[int] = None,
    total_records_to_sync: Optional[int] = None,
    total_records_to_download: Optional[int] = None,
    total_records_to_upload: Optional[int] = None,
    records_uploaded: Optional[int] = None,
    bytes_downloaded: Optional[int] = None,
    bytes_uploaded: Optional[int] = None,
    dataset_sync_hashes: Optional[dict] = None,
    error_message: Optional[str] = None,
    retry_count: Optional[int] = None,
    started_at: Optional[datetime] = None,
    completed_at: Optional[datetime] = None,
) -> Optional[SyncOperation]:
    """
    Update a sync operation record with new status/progress information.

    Only keyword arguments that are not None are written; omitted fields keep
    their stored values. The whole read-modify-write is retried on transient
    database errors via _retry_db_operation.

    Args:
        run_id: The public run_id of the sync operation to update
        status: New status for the operation
        progress_percentage: Progress percentage (0-100)
        records_downloaded: Number of records downloaded so far
        total_records_to_sync: Total number of records that need to be synced
        total_records_to_download: Total number of records to download from cloud
        total_records_to_upload: Total number of records to upload to cloud
        records_uploaded: Number of records uploaded so far
        bytes_downloaded: Total bytes downloaded from cloud
        bytes_uploaded: Total bytes uploaded to cloud
        dataset_sync_hashes: Dict mapping dataset_id -> {uploaded: [hashes], downloaded: [hashes]}
        error_message: Error message if operation failed
        retry_count: Number of retry attempts
        started_at: When the actual processing started
        completed_at: When the operation completed (success or failure)

    Returns:
        SyncOperation: The updated sync operation record, or None if not found
    """

    # Closure captures all keyword arguments so the full read-modify-write can
    # be re-run as a unit by the retry wrapper.
    async def _perform_update():
        db_engine = get_relational_engine()

        async with db_engine.get_async_session() as session:
            try:
                # Find the sync operation
                query = select(SyncOperation).where(SyncOperation.run_id == run_id)
                result = await session.execute(query)
                sync_operation = result.scalars().first()

                if not sync_operation:
                    logger.warning(f"Sync operation not found for run_id: {run_id}")
                    return None

                # Log what we're updating for debugging
                updates = []
                if status is not None:
                    updates.append(f"status={status.value}")
                if progress_percentage is not None:
                    updates.append(f"progress={progress_percentage}%")
                if records_downloaded is not None:
                    updates.append(f"downloaded={records_downloaded}")
                if records_uploaded is not None:
                    updates.append(f"uploaded={records_uploaded}")
                if total_records_to_sync is not None:
                    updates.append(f"total_sync={total_records_to_sync}")
                if total_records_to_download is not None:
                    updates.append(f"total_download={total_records_to_download}")
                if total_records_to_upload is not None:
                    updates.append(f"total_upload={total_records_to_upload}")

                if updates:
                    logger.debug(f"Updating sync operation {run_id}: {', '.join(updates)}")

                # Update fields that were provided
                if status is not None:
                    sync_operation.status = status

                if progress_percentage is not None:
                    # Clamp to the valid 0-100 range.
                    sync_operation.progress_percentage = max(0, min(100, progress_percentage))

                if records_downloaded is not None:
                    sync_operation.records_downloaded = records_downloaded

                if records_uploaded is not None:
                    sync_operation.records_uploaded = records_uploaded

                if total_records_to_sync is not None:
                    sync_operation.total_records_to_sync = total_records_to_sync

                if total_records_to_download is not None:
                    sync_operation.total_records_to_download = total_records_to_download

                if total_records_to_upload is not None:
                    sync_operation.total_records_to_upload = total_records_to_upload

                if bytes_downloaded is not None:
                    sync_operation.bytes_downloaded = bytes_downloaded

                if bytes_uploaded is not None:
                    sync_operation.bytes_uploaded = bytes_uploaded

                if dataset_sync_hashes is not None:
                    sync_operation.dataset_sync_hashes = dataset_sync_hashes

                if error_message is not None:
                    sync_operation.error_message = error_message

                if retry_count is not None:
                    sync_operation.retry_count = retry_count

                if started_at is not None:
                    sync_operation.started_at = started_at

                if completed_at is not None:
                    sync_operation.completed_at = completed_at

                # Auto-set completion timestamp for terminal statuses
                if (
                    status in [SyncStatus.COMPLETED, SyncStatus.FAILED, SyncStatus.CANCELLED]
                    and completed_at is None
                ):
                    sync_operation.completed_at = datetime.now(timezone.utc)

                # Auto-set started timestamp when moving to IN_PROGRESS
                if status == SyncStatus.IN_PROGRESS and sync_operation.started_at is None:
                    sync_operation.started_at = datetime.now(timezone.utc)

                await session.commit()
                await session.refresh(sync_operation)

                logger.debug(f"Successfully updated sync operation {run_id}")
                return sync_operation

            except SQLAlchemyError as e:
                logger.error(
                    f"Database error updating sync operation {run_id}: {str(e)}", exc_info=True
                )
                await session.rollback()
                raise
            except Exception as e:
                logger.error(
                    f"Unexpected error updating sync operation {run_id}: {str(e)}", exc_info=True
                )
                await session.rollback()
                raise

    # Use retry logic for the database operation
    return await _retry_db_operation(_perform_update, run_id)
210
+
211
+
212
async def mark_sync_started(run_id: str) -> Optional[SyncOperation]:
    """Flip a sync operation to IN_PROGRESS and stamp its start time."""
    now = datetime.now(timezone.utc)
    return await update_sync_operation(
        run_id=run_id,
        status=SyncStatus.IN_PROGRESS,
        started_at=now,
    )
217
+
218
+
219
async def mark_sync_completed(
    run_id: str,
    records_downloaded: int = 0,
    records_uploaded: int = 0,
    bytes_downloaded: int = 0,
    bytes_uploaded: int = 0,
    dataset_sync_hashes: Optional[dict] = None,
) -> Optional[SyncOperation]:
    """Mark a sync operation COMPLETED at 100% progress with its final transfer counters."""
    finished_at = datetime.now(timezone.utc)
    return await update_sync_operation(
        run_id=run_id,
        status=SyncStatus.COMPLETED,
        progress_percentage=100,
        records_downloaded=records_downloaded,
        records_uploaded=records_uploaded,
        bytes_downloaded=bytes_downloaded,
        bytes_uploaded=bytes_uploaded,
        dataset_sync_hashes=dataset_sync_hashes,
        completed_at=finished_at,
    )
239
+
240
+
241
async def mark_sync_failed(run_id: str, error_message: str) -> Optional[SyncOperation]:
    """Mark a sync operation FAILED, recording the error and completion time."""
    failed_at = datetime.now(timezone.utc)
    return await update_sync_operation(
        run_id=run_id,
        status=SyncStatus.FAILED,
        error_message=error_message,
        completed_at=failed_at,
    )