fastapi-offline-sync 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastapi_offline_sync/__init__.py +5 -0
- fastapi_offline_sync/cli.py +26 -0
- fastapi_offline_sync/config.py +90 -0
- fastapi_offline_sync/exceptions.py +29 -0
- fastapi_offline_sync/hlc.py +92 -0
- fastapi_offline_sync/metrics.py +39 -0
- fastapi_offline_sync/mongo.py +55 -0
- fastapi_offline_sync/resolver.py +64 -0
- fastapi_offline_sync/router.py +138 -0
- fastapi_offline_sync/schemas/__init__.py +20 -0
- fastapi_offline_sync/schemas/common.py +42 -0
- fastapi_offline_sync/schemas/full.py +22 -0
- fastapi_offline_sync/schemas/pull.py +34 -0
- fastapi_offline_sync/schemas/push.py +36 -0
- fastapi_offline_sync/schemas/stream.py +30 -0
- fastapi_offline_sync/service.py +562 -0
- fastapi_offline_sync-0.1.1.dist-info/METADATA +206 -0
- fastapi_offline_sync-0.1.1.dist-info/RECORD +20 -0
- fastapi_offline_sync-0.1.1.dist-info/WHEEL +4 -0
- fastapi_offline_sync-0.1.1.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import asyncio
|
|
5
|
+
|
|
6
|
+
from fastapi_offline_sync.config import SyncConfig
|
|
7
|
+
from fastapi_offline_sync.service import SyncService
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
async def _init_command(args: argparse.Namespace) -> None:
    """Run the ``init`` subcommand.

    Builds a ``SyncConfig`` from the parsed ``--mongodb-uri`` and
    ``--database-name`` options, wraps it in a ``SyncService`` and awaits the
    service's ``initialize()`` coroutine.
    """
    settings = SyncConfig(
        mongodb_uri=args.mongodb_uri,
        database_name=args.database_name,
    )
    await SyncService(settings).initialize()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def main() -> None:
    """Console entry point for the ``fastapi-offline-sync`` command.

    Defines a single required subcommand, ``init``, which needs both
    ``--mongodb-uri`` and ``--database-name``. Arguments are read from
    ``sys.argv``; argparse exits the process on invalid input.
    """
    cli = argparse.ArgumentParser(prog="fastapi-offline-sync")
    commands = cli.add_subparsers(dest="command", required=True)

    init_cmd = commands.add_parser("init")
    for flag in ("--mongodb-uri", "--database-name"):
        init_cmd.add_argument(flag, required=True)

    parsed = cli.parse_args()
    if parsed.command != "init":
        return
    asyncio.run(_init_command(parsed))
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from pydantic import Field
|
|
8
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
AuthorizationCallback = Callable[[Any, str], bool | Awaitable[bool]]
|
|
12
|
+
JWTDependency = Callable[..., Any]
|
|
13
|
+
WebSocketIdentityResolver = Callable[..., Any]
|
|
14
|
+
ConflictResolver = Callable[[str, Any, dict[str, Any] | None, dict[str, Any], str | None, str], dict[str, Any] | None]
|
|
15
|
+
IdentityResolver = Callable[[Any], str | None]
|
|
16
|
+
IncomingDocumentTransform = Callable[[str, dict[str, Any] | None, Any], dict[str, Any] | None | Awaitable[dict[str, Any] | None]]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(slots=True)
|
|
20
|
+
class CollectionPolicy:
|
|
21
|
+
delete_update_policy: str = "resurrect_on_update"
|
|
22
|
+
|
|
23
|
+
def __post_init__(self) -> None:
|
|
24
|
+
if self.delete_update_policy not in {"resurrect_on_update", "reject_update", "delete_wins"}:
|
|
25
|
+
raise ValueError(f"Unsupported delete_update_policy: {self.delete_update_policy}")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(slots=True)
class SyncCollectionConfig:
    """Settings for a single synchronized collection.

    Only ``name`` is required. Full resync is allowed unless switched off, and
    every instance gets its own default ``CollectionPolicy``.
    """

    name: str
    allow_full_resync: bool = True
    # A factory (not a shared instance) so configs never share a policy object.
    policy: CollectionPolicy = field(default_factory=CollectionPolicy)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class SyncConfig(BaseSettings):
    # Settings for the sync engine. Values may come from keyword arguments or
    # from environment variables prefixed with ``FASTAPI_OFFLINE_SYNC_``.

    # --- storage -----------------------------------------------------------
    mongodb_uri: str
    database_name: str

    # --- routing / oplog / paging -----------------------------------------
    sync_prefix: str = "/sync"
    sync_oplog_collection: str = "sync_oplog"
    oplog_ttl_days: int = 30
    default_pull_limit: int = 500
    max_pull_limit: int = 1000
    default_full_limit: int = 1000

    # --- collections and per-collection behavior --------------------------
    collections: Sequence[str] = Field(default_factory=tuple)
    collection_configs: Mapping[str, SyncCollectionConfig] = Field(default_factory=dict)
    conflict_resolvers: Mapping[str, ConflictResolver] = Field(default_factory=dict)

    # --- pluggable callbacks ----------------------------------------------
    jwt_dependency: JWTDependency | None = None
    websocket_identity_resolver: WebSocketIdentityResolver | None = None
    authorization_callback: AuthorizationCallback | None = None
    incoming_document_transform: IncomingDocumentTransform | None = None

    # --- identity extraction ----------------------------------------------
    actor_id_field: str = "user_id"
    scope_id_field: str = "user_id"
    actor_id_resolver: IdentityResolver | None = None
    scope_id_resolver: IdentityResolver | None = None

    # --- write semantics / observability ----------------------------------
    batch_atomic: bool = True
    use_transactions: bool = True
    metrics_enabled: bool = True
    hlc_node_id: str | None = None
    """Optional node identifier appended to HLC timestamps (e.g. ``"a3f2"``).

    Set this to any short unique string per deployment unit (e.g. last 4 hex
    of the pod IP, a worker index, or a UUID prefix). When ``None`` the
    service will auto-derive a node ID from the local MAC address, which is
    sufficient for single-host setups.
    """

    model_config = SettingsConfigDict(
        env_prefix="FASTAPI_OFFLINE_SYNC_",
        arbitrary_types_allowed=True,
    )

    def configured_collections(self) -> tuple[str, ...]:
        """Return the synchronized collection names.

        The explicit ``collections`` list wins when it is non-empty; otherwise
        the keys of ``collection_configs`` are used.
        """
        return tuple(self.collections) if self.collections else tuple(self.collection_configs.keys())

    def delete_update_policy_for(self, collection: str) -> str:
        """Return the delete/update policy of ``collection``.

        Collections without an entry in ``collection_configs`` use
        ``"resurrect_on_update"``.
        """
        entry = self.collection_configs.get(collection)
        return "resurrect_on_update" if entry is None else entry.policy.delete_update_policy

    def full_resync_allowed_for(self, collection: str) -> bool:
        """Return whether full resync is allowed for ``collection`` (default ``True``)."""
        entry = self.collection_configs.get(collection)
        return True if entry is None else entry.allow_full_resync

    def resolver_for(self, collection: str) -> ConflictResolver | None:
        """Return the custom conflict resolver registered for ``collection``, if any."""
        return self.conflict_resolvers.get(collection)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
class SyncError(Exception):
    """Root of the sync engine's exception hierarchy."""
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class ConflictRejection(SyncError):
    """Signals that a conflict resolver has permanently rejected a change."""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FullResyncRequired(SyncError):
    """Signals that the client's sequence predates the oplog retention window.

    Attributes:
        collections: the collection names passed by the raiser.
    """

    def __init__(self, collections: list[str]) -> None:
        """Store ``collections`` and set the fixed message ``"Full resync required"``."""
        self.collections = collections
        super().__init__("Full resync required")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AuthorizationDenied(SyncError):
    """Signals that a user is not authorized for a collection."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TransactionUnavailable(SyncError):
    """Signals that a MongoDB transaction could not be started.

    The usual cause is a server that is not a member of a replica set.
    Atomic push operations depend on transactions: writing without one can
    leave a partially committed batch behind and corrupt sync state.
    Pass ``use_transactions=False`` to ``SyncConfig`` only when best-effort
    consistency is explicitly acceptable (not recommended for production).
    """
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import uuid
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from threading import Lock
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _derive_node_id() -> str:
    """Build a 4-character lowercase hex node ID from the host's hardware address.

    The value comes from ``uuid.getnode()``, keeping only its lowest 16 bits.
    When no hardware address can be found, ``uuid.getnode()`` substitutes a
    random number, so this function always returns a usable ID.
    """
    low_bits = uuid.getnode() & 0xFFFF
    return f"{low_bits:04x}"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def validate_hlc(value: str) -> str:
    """Check that ``value`` is a well-formed HLC string and return it unchanged.

    Two layouts are accepted:

    * legacy, two parts: ``20260502T143000.000Z-0001``
    * node-aware, three parts: ``20260502T143000.000Z-0001-a3f2``

    The first part must parse with ``%Y%m%dT%H%M%S.%fZ`` and the second must
    be an integer. The optional third part (node ID) is not inspected.

    Raises:
        ValueError: if the structure, timestamp or counter is invalid.
    """
    # maxsplit=2 keeps any dashes inside the node ID in the final segment.
    segments = value.split("-", maxsplit=2)
    try:
        if not 2 <= len(segments) <= 3:
            raise ValueError("wrong number of parts")
        datetime.strptime(segments[0], "%Y%m%dT%H%M%S.%fZ")
        counter = int(segments[1])
    except ValueError as exc:
        raise ValueError(f"Invalid HLC value: {value}") from exc

    if counter < 0:
        raise ValueError(f"Invalid HLC counter: {value}")
    return value
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass(slots=True)
|
|
43
|
+
class HLCState:
|
|
44
|
+
last_timestamp: datetime = field(default_factory=lambda: datetime.fromtimestamp(0, tz=timezone.utc))
|
|
45
|
+
counter: int = 0
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class HybridLogicalClock:
|
|
49
|
+
def __init__(self, node_id: str | None = None) -> None:
|
|
50
|
+
self._lock = Lock()
|
|
51
|
+
self._state = HLCState()
|
|
52
|
+
self._node_id = node_id
|
|
53
|
+
|
|
54
|
+
@staticmethod
|
|
55
|
+
def derive_node_id() -> str:
|
|
56
|
+
"""Return a 4-char hex node ID derived from the local MAC address.
|
|
57
|
+
|
|
58
|
+
Suitable for single-host multi-worker setups. For multi-host
|
|
59
|
+
deployments supply an explicit ``hlc_node_id`` in ``SyncConfig``
|
|
60
|
+
(e.g. last 4 hex of the pod IP or a short UUID prefix).
|
|
61
|
+
"""
|
|
62
|
+
return _derive_node_id()
|
|
63
|
+
|
|
64
|
+
def now(self) -> str:
|
|
65
|
+
with self._lock:
|
|
66
|
+
current = datetime.now(tz=timezone.utc)
|
|
67
|
+
if current > self._state.last_timestamp:
|
|
68
|
+
self._state.last_timestamp = current
|
|
69
|
+
self._state.counter = 0
|
|
70
|
+
else:
|
|
71
|
+
self._state.counter += 1
|
|
72
|
+
|
|
73
|
+
return self._format(self._state.last_timestamp, self._state.counter, self._node_id)
|
|
74
|
+
|
|
75
|
+
@staticmethod
|
|
76
|
+
def compare(left: str, right: str) -> int:
|
|
77
|
+
validate_hlc(left)
|
|
78
|
+
validate_hlc(right)
|
|
79
|
+
# Strip optional node_id suffix before comparing so that clocks from
|
|
80
|
+
# different nodes with the same timestamp+counter are considered equal.
|
|
81
|
+
left_core = "-".join(left.split("-", maxsplit=2)[:2])
|
|
82
|
+
right_core = "-".join(right.split("-", maxsplit=2)[:2])
|
|
83
|
+
if left_core == right_core:
|
|
84
|
+
return 0
|
|
85
|
+
return -1 if left_core < right_core else 1
|
|
86
|
+
|
|
87
|
+
@staticmethod
|
|
88
|
+
def _format(timestamp: datetime, counter: int, node_id: str | None) -> str:
|
|
89
|
+
base = f"{timestamp.strftime('%Y%m%dT%H%M%S.%f')[:-3]}Z-{counter:04d}"
|
|
90
|
+
if node_id:
|
|
91
|
+
return f"{base}-{node_id}"
|
|
92
|
+
return base
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Final
|
|
5
|
+
|
|
6
|
+
from prometheus_client import Counter, Gauge
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(slots=True)
|
|
10
|
+
class SyncMetrics:
|
|
11
|
+
push_total: Counter
|
|
12
|
+
push_changes_total: Counter
|
|
13
|
+
conflict_total: Counter
|
|
14
|
+
pull_total: Counter
|
|
15
|
+
pull_changes_served: Counter
|
|
16
|
+
full_resync_total: Counter
|
|
17
|
+
oplog_size: Gauge
|
|
18
|
+
websocket_connections: Gauge
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Process-wide cache. The collectors are created only once because each
# Counter/Gauge is constructed with a fixed metric name.
_METRICS: SyncMetrics | None = None


def build_metrics() -> SyncMetrics:
    """Return the shared ``SyncMetrics``, creating it on the first call."""
    global _METRICS

    if _METRICS is not None:
        return _METRICS

    _METRICS = SyncMetrics(
        push_total=Counter("sync_push_total", "Push requests", labelnames=("status",)),
        push_changes_total=Counter("sync_push_changes_total", "Changes pushed"),
        conflict_total=Counter("sync_conflict_total", "Conflicts detected", labelnames=("collection",)),
        pull_total=Counter("sync_pull_total", "Pull requests"),
        pull_changes_served=Counter("sync_pull_changes_served", "Changes served from pull"),
        full_resync_total=Counter("sync_full_resync_total", "Full resync requests"),
        oplog_size=Gauge("sync_oplog_size", "Estimated oplog document count"),
        websocket_connections=Gauge("sync_websocket_connections", "Active websocket connections"),
    )
    return _METRICS
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorCollection, AsyncIOMotorDatabase
|
|
7
|
+
from pymongo import ASCENDING, DESCENDING
|
|
8
|
+
from pymongo.errors import CollectionInvalid
|
|
9
|
+
|
|
10
|
+
from fastapi_offline_sync.config import SyncConfig
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MongoManager:
    """Owns the Motor client and exposes the collections the sync engine uses."""

    def __init__(self, config: SyncConfig) -> None:
        """Create the client for ``config.mongodb_uri`` and select ``config.database_name``."""
        self._config = config
        self._client = AsyncIOMotorClient(config.mongodb_uri)
        self._database = self._client[config.database_name]

    @property
    def database(self) -> AsyncIOMotorDatabase:
        """The selected database handle."""
        return self._database

    @property
    def oplog(self) -> AsyncIOMotorCollection:
        """The oplog collection, named by ``config.sync_oplog_collection``."""
        return self._database[self._config.sync_oplog_collection]

    def collection(self, name: str) -> AsyncIOMotorCollection:
        """Return the collection called ``name`` in the selected database."""
        return self._database[name]

    async def initialize(self) -> None:
        """Create the oplog collection and all indexes.

        Indexes created: a TTL index on the oplog ``timestamp`` field
        (``oplog_ttl_days`` converted to seconds), a compound oplog index on
        ``scope_id``/``collection``/``_id``, and the business indexes of every
        configured collection.
        """
        oplog_name = self._config.sync_oplog_collection
        try:
            await self._database.create_collection(oplog_name)
        except CollectionInvalid:
            # Ignored on purpose; presumably the collection already exists.
            pass

        ttl_seconds = self._config.oplog_ttl_days * 86400
        await self.oplog.create_index([("timestamp", ASCENDING)], expireAfterSeconds=ttl_seconds)
        await self.oplog.create_index([("scope_id", ASCENDING), ("collection", ASCENDING), ("_id", ASCENDING)])

        for name in self._config.configured_collections():
            await self.ensure_business_indexes(name)

    async def ensure_business_indexes(self, collection_name: str) -> None:
        """Create the two sync indexes on the business collection ``collection_name``."""
        target = self.collection(collection_name)
        await target.create_index([("_sync_scope_id", ASCENDING), ("_sync_deleted", ASCENDING), ("_id", ASCENDING)])
        await target.create_index([("_sync_version", DESCENDING)])

    async def estimate_oplog_size(self) -> int:
        """Return the estimated number of documents in the oplog collection."""
        return await self.oplog.estimated_document_count()

    async def transaction(self):
        """Start and return a new client session.

        No transaction is started here; that is left to the caller.
        """
        return await self._client.start_session()

    @asynccontextmanager
    async def watch_oplog(self, pipeline: list[dict[str, Any]], *, full_document: str = "updateLookup"):
        """Async context manager yielding a change stream over the oplog collection."""
        async with self.oplog.watch(pipeline=pipeline, full_document=full_document) as change_stream:
            yield change_stream
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from fastapi_offline_sync.exceptions import ConflictRejection
|
|
7
|
+
from fastapi_offline_sync.hlc import HybridLogicalClock
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(slots=True)
|
|
11
|
+
class ResolutionResult:
|
|
12
|
+
document: dict[str, Any] | None
|
|
13
|
+
operation: str
|
|
14
|
+
status: str
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class LastWriterWinsByHLC:
    """Conflict resolver that lets the change with the newest HLC parent win."""

    def resolve(
        self,
        *,
        collection: str,
        doc_id: str | int,
        current_doc: dict[str, Any] | None,
        incoming_change: dict[str, Any],
        current_version: str | None,
        incoming_parent_version: str,
        delete_update_policy: str,
    ) -> ResolutionResult:
        """Decide whether ``incoming_change`` replaces the stored document.

        Rules, in order:

        1. With no stored document or version, the change is ``"accepted"``.
        2. An ``"upsert"`` against a document flagged ``_sync_deleted`` follows
           ``delete_update_policy``: ``"reject_update"`` raises,
           ``"delete_wins"`` returns a rejected delete, and any other policy
           falls through to rule 3.
        3. If ``incoming_parent_version`` is not older than
           ``current_version`` (node IDs ignored), the change wins with status
           ``"conflict_resolved"``; otherwise the stored document is returned
           with status ``"rejected"``.

        Args:
            collection: collection name; not used by this strategy.
            doc_id: document identifier, used only in the rejection message.
            current_doc: the stored document, or ``None``.
            incoming_change: mapping with ``"operation"`` and optional ``"data"``.
            current_version: HLC version of the stored document, or ``None``.
            incoming_parent_version: HLC version the client based its change on.
            delete_update_policy: see rule 2.

        Raises:
            ConflictRejection: rule 2 with ``"reject_update"``.
            ValueError: if a version string is not a valid HLC.
        """
        del collection  # part of the resolver signature; unused here

        if current_doc is None or current_version is None:
            return ResolutionResult(
                document=incoming_change.get("data"),
                operation=incoming_change["operation"],
                status="accepted",
            )

        if incoming_change["operation"] == "upsert" and current_doc.get("_sync_deleted"):
            if delete_update_policy == "reject_update":
                raise ConflictRejection(f"Update rejected for deleted document: {doc_id}")
            if delete_update_policy == "delete_wins":
                return ResolutionResult(document=None, operation="delete", status="rejected")

        # NOTE(review): an equal parent version (no concurrent write) is also
        # reported as "conflict_resolved"; confirm callers expect that.
        if HybridLogicalClock.compare(incoming_parent_version, current_version) >= 0:
            return ResolutionResult(
                document=incoming_change.get("data"),
                operation=incoming_change["operation"],
                status="conflict_resolved",
            )

        # The former doc_id "tie-break" branch was removed: identical version
        # strings always compare equal and are handled above, so it could
        # never run.
        return ResolutionResult(
            document=current_doc,
            operation="upsert" if not current_doc.get("_sync_deleted") else "delete",
            status="rejected",
        )
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import contextlib
|
|
5
|
+
from typing import Any, Callable, cast
|
|
6
|
+
|
|
7
|
+
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect, status
|
|
8
|
+
from fastapi.responses import StreamingResponse
|
|
9
|
+
from prometheus_fastapi_instrumentator import Instrumentator
|
|
10
|
+
|
|
11
|
+
from fastapi_offline_sync.config import SyncConfig
|
|
12
|
+
from fastapi_offline_sync.exceptions import AuthorizationDenied, FullResyncRequired
|
|
13
|
+
from fastapi_offline_sync.schemas import FullResyncRequiredResponse, FullSyncQuery, PullQuery, PullResponse, PushRequest, PushResponse, StreamSubscribeMessage
|
|
14
|
+
from fastapi_offline_sync.service import SyncService
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _noop_user(*, token: str | None = None) -> dict[str, str]:
|
|
18
|
+
del token
|
|
19
|
+
return {"user_id": "anonymous"}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SyncRouter(APIRouter):
    """FastAPI router exposing the sync endpoints.

    Routes, relative to ``config.sync_prefix``: ``POST /push``, ``GET /pull``,
    ``GET /full`` and the ``/stream`` websocket.
    """

    def __init__(self, config: SyncConfig) -> None:
        """Build the service, choose the auth dependencies and register routes.

        HTTP routes use ``config.jwt_dependency``, falling back to the
        anonymous ``_noop_user``. The websocket uses
        ``config.websocket_identity_resolver``, falling back to the HTTP
        dependency.
        """
        self.config = config
        self.service = SyncService(config)
        dependency = config.jwt_dependency or _noop_user
        self._auth_dependency = dependency
        self._websocket_auth_dependency: Callable[..., Any] = cast(
            Callable[..., Any],
            config.websocket_identity_resolver or dependency,
        )

        super().__init__(prefix=config.sync_prefix, tags=["sync"])
        self.add_event_handler("startup", self.install)

        # The endpoints are closures so that Depends()/Query() can capture the
        # per-instance dependency and limits. They carry no docstrings on
        # purpose: FastAPI would publish them as OpenAPI descriptions.
        async def push_endpoint(payload: PushRequest, user: Any = Depends(dependency)) -> PushResponse:
            return await self._handle_push(payload, user)

        async def pull_endpoint(
            since: str,
            collections: str | None = None,
            limit: int = Query(default=config.default_pull_limit, ge=1, le=config.max_pull_limit),
            user: Any = Depends(dependency),
        ) -> PullResponse | FullResyncRequiredResponse:
            return await self._handle_pull(since=since, collections=collections, limit=limit, user=user)

        async def full_sync_endpoint(
            collections: str,
            cursor: str | None = None,
            limit: int = Query(default=config.default_full_limit, ge=1, le=config.max_pull_limit),
            user: Any = Depends(dependency),
        ) -> StreamingResponse:
            return await self._handle_full_sync(collections=collections, cursor=cursor, limit=limit, user=user)

        self.add_api_route("/push", push_endpoint, methods=["POST"], response_model=PushResponse)
        self.add_api_route("/pull", pull_endpoint, methods=["GET"], response_model=PullResponse | FullResyncRequiredResponse)
        self.add_api_route("/full", full_sync_endpoint, methods=["GET"])
        self.add_api_websocket_route("/stream", self.stream)

    async def install(self) -> None:
        """Startup handler: initialize the underlying sync service."""
        await self.service.initialize()

    def instrument(self, app: Any) -> None:
        """Attach the Prometheus instrumentator to ``app`` unless metrics are disabled."""
        if not self.config.metrics_enabled:
            return
        Instrumentator().instrument(app).expose(app)

    async def _handle_push(self, payload: PushRequest, user: Any) -> PushResponse:
        """Forward a push to the service, mapping ``AuthorizationDenied`` to HTTP 403."""
        try:
            return await self.service.push(payload, user)
        except AuthorizationDenied as exc:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(exc)) from exc

    async def _handle_pull(self, since: str, collections: str | None, limit: int, user: Any) -> PullResponse | FullResyncRequiredResponse:
        """Validate the pull parameters and forward them to the service.

        ``collections`` is a comma-separated list; ``None`` or an empty string
        becomes an empty list.

        Raises:
            HTTPException: 422 for invalid parameters (for example a malformed
                ``since`` HLC), 403 when the service denies authorization.
        """
        try:
            parsed = PullQuery(since=since, collections=collections.split(",") if collections else [], limit=limit)
        except ValueError as exc:
            # pydantic's ValidationError subclasses ValueError. Raised inside a
            # handler it would otherwise surface as HTTP 500, although the
            # client's input is at fault.
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        try:
            return await self.service.pull(since=parsed.since, collections=parsed.collections, limit=parsed.limit, user=user)
        except FullResyncRequired as exc:
            return FullResyncRequiredResponse(collections=exc.collections)
        except AuthorizationDenied as exc:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(exc)) from exc

    async def _handle_full_sync(self, collections: str, cursor: str | None, limit: int, user: Any) -> StreamingResponse:
        """Run a full sync and stream the result as ``application/x-ndjson``.

        Raises:
            HTTPException: 422 for invalid parameters, 403 when the service
                denies authorization.
        """
        try:
            query = FullSyncQuery(collections=collections.split(","), cursor=cursor, limit=limit)
        except ValueError as exc:
            # Same reasoning as in _handle_pull: report bad input as 422.
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        try:
            body, headers = await self.service.full_sync(query, user)
        except AuthorizationDenied as exc:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(exc)) from exc
        return StreamingResponse(body, media_type="application/x-ndjson", headers=headers)

    async def stream(self, websocket: WebSocket) -> None:
        """Serve one websocket subscription.

        The first JSON message is the subscription. ``token`` and ``scope_id``
        are read from it, falling back to the query string, and are used to
        resolve the user. Change events and heartbeats are then forwarded
        concurrently until either forwarder finishes or the client disconnects.
        """
        subscription_task: asyncio.Task[None] | None = None
        heartbeat_task: asyncio.Task[None] | None = None
        try:
            await websocket.accept()
            message = await websocket.receive_json()
            token = message.get("token") or websocket.query_params.get("token")
            scope_id = message.get("scope_id") or websocket.query_params.get("scope_id")
            user = await self._resolve_websocket_user(token, scope_id)
            payload = dict(message)
            # The token is removed before validation; scope_id stays in the payload.
            payload.pop("token", None)
            subscription = StreamSubscribeMessage.model_validate(payload)

            async def forward_events() -> None:
                async for event in self.service.stream_changes(subscription, user):
                    await websocket.send_json(event)

            async def forward_heartbeat() -> None:
                async for ping in self.service.heartbeat():
                    await websocket.send_json(ping)

            subscription_task = asyncio.create_task(forward_events())
            heartbeat_task = asyncio.create_task(forward_heartbeat())
            done, pending = await asyncio.wait({subscription_task, heartbeat_task}, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                # Awaiting re-raises whatever exception ended the finished task.
                await task
            for task in pending:
                task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await task
        except (WebSocketDisconnect, AuthorizationDenied):
            return
        finally:
            if subscription_task is not None:
                subscription_task.cancel()
            if heartbeat_task is not None:
                heartbeat_task.cancel()
            # RuntimeError is suppressed here; presumably it covers an
            # already-closed socket.
            with contextlib.suppress(RuntimeError):
                await websocket.close()

    async def _resolve_websocket_user(self, token: str | None, scope_id: str | None = None) -> Any:
        """Resolve the websocket user through the configured dependency.

        The dependency receives ``token`` and ``scope_id`` as keyword arguments
        when at least one is present, and no arguments otherwise. A coroutine
        result is awaited.
        """
        if token is not None or scope_id is not None:
            resolved = self._websocket_auth_dependency(token=token, scope_id=scope_id)
        else:
            resolved = self._websocket_auth_dependency()
        if asyncio.iscoroutine(resolved):
            return await resolved
        return resolved
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from .common import OplogEntry, SyncChangePayload
|
|
2
|
+
from .full import FullSyncQuery, FullSyncRecord
|
|
3
|
+
from .pull import FullResyncRequiredResponse, PullQuery, PullResponse
|
|
4
|
+
from .push import PushRequest, PushResponse, PushResult
|
|
5
|
+
from .stream import StreamControlMessage, StreamSubscribeMessage
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"FullResyncRequiredResponse",
|
|
9
|
+
"FullSyncQuery",
|
|
10
|
+
"FullSyncRecord",
|
|
11
|
+
"OplogEntry",
|
|
12
|
+
"PullQuery",
|
|
13
|
+
"PullResponse",
|
|
14
|
+
"PushRequest",
|
|
15
|
+
"PushResponse",
|
|
16
|
+
"PushResult",
|
|
17
|
+
"StreamControlMessage",
|
|
18
|
+
"StreamSubscribeMessage",
|
|
19
|
+
"SyncChangePayload",
|
|
20
|
+
]
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
6
|
+
|
|
7
|
+
from fastapi_offline_sync.hlc import validate_hlc
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SyncModel(BaseModel):
    # Base class for the sync schemas: fields can be populated by name, and
    # unknown fields are rejected.
    model_config = ConfigDict(
        populate_by_name=True,
        extra="forbid",
    )
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SyncChangePayload(SyncModel):
    # A single change pushed by a client.
    collection: str
    operation: str
    doc_id: str | int = Field(serialization_alias="doc_id")
    data: dict[str, Any] | None = None
    parent_version: str | None = None

    @field_validator("operation")
    @classmethod
    def validate_operation(cls, value: str) -> str:
        """Accept only ``"upsert"`` and ``"delete"``."""
        if value in {"upsert", "delete"}:
            return value
        raise ValueError("operation must be 'upsert' or 'delete'")

    @field_validator("parent_version")
    @classmethod
    def validate_parent_version(cls, value: str | None) -> str | None:
        """Normalize ``None`` and ``""`` to ``None``; otherwise require a valid HLC."""
        return None if value in {None, ""} else validate_hlc(value)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class OplogEntry(SyncModel):
    # One oplog record as served to clients. ``data`` has no default, so it
    # must be supplied, though it may be None.
    version: str
    collection: str
    doc_id: str | int
    operation: str
    data: dict[str, Any] | None
    parent_version: str | None = None
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from fastapi_offline_sync.schemas.common import SyncModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class FullSyncQuery(SyncModel):
    # Parameters of a full-sync request: the collections to fetch, an optional
    # cursor and a limit (1000 by default).
    collections: list[str]
    cursor: str | None = None
    limit: int = 1000
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class FullSyncRecord(SyncModel):
    # One document in a full-sync stream.
    collection: str
    doc_id: str | int
    data: dict[str, object] | None
    deleted: bool = False
    """True when this document was soft-deleted on the server.

    Clients should remove the record from their local store when they see
    ``deleted=True``. ``data`` will be ``None`` for deleted records.
    """
|
|
22
|
+
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pydantic import field_validator
|
|
4
|
+
|
|
5
|
+
from fastapi_offline_sync.hlc import validate_hlc
|
|
6
|
+
from fastapi_offline_sync.schemas.common import OplogEntry, SyncModel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class PullQuery(SyncModel):
    # Validated parameters of a pull request.
    since: str
    collections: list[str] = []
    limit: int = 500

    @field_validator("since")
    @classmethod
    def validate_since(cls, value: str) -> str:
        """Require ``since`` to be a well-formed HLC string."""
        return validate_hlc(value)

    @field_validator("limit")
    @classmethod
    def validate_limit(cls, value: int) -> int:
        """Require ``limit`` to be between 1 and 1000, inclusive."""
        # NOTE(review): this bound is fixed at 1000 and does not follow
        # SyncConfig.max_pull_limit.
        if 1 <= value <= 1000:
            return value
        raise ValueError("limit must be between 1 and 1000")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class PullResponse(SyncModel):
    # Result of a pull: the oplog entries served, plus ``last_seq``.
    # NOTE(review): ``last_seq`` is presumably the resume position for the
    # next pull - confirm in service.
    changes: list[OplogEntry]
    last_seq: str
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class FullResyncRequiredResponse(SyncModel):
    # Returned by the pull endpoint in place of changes when the service
    # raises FullResyncRequired; lists the collections from that exception.
    full_resync_required: bool = True
    collections: list[str]
|