lab-link 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lab_link-0.1.1/PKG-INFO +45 -0
- lab_link-0.1.1/README.md +20 -0
- lab_link-0.1.1/pyproject.toml +39 -0
- lab_link-0.1.1/src/lab_link/__init__.py +5 -0
- lab_link-0.1.1/src/lab_link/connection_manager.py +112 -0
- lab_link-0.1.1/src/lab_link/core.py +547 -0
- lab_link-0.1.1/src/lab_link/errors.py +49 -0
- lab_link-0.1.1/src/lab_link/persistence.py +58 -0
- lab_link-0.1.1/src/lab_link/pointer.py +11 -0
- lab_link-0.1.1/src/lab_link/proxy.py +122 -0
- lab_link-0.1.1/src/lab_link/py.typed +0 -0
- lab_link-0.1.1/src/lab_link/state_store.py +153 -0
- lab_link-0.1.1/src/lab_link/stream_buffer.py +191 -0
lab_link-0.1.1/PKG-INFO
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: lab-link
|
|
3
|
+
Version: 0.1.1
|
|
4
|
+
Summary: Generic server-authoritative sync library for FastAPI
|
|
5
|
+
Author: Andrew Mueller
|
|
6
|
+
Author-email: Andrew Mueller <amueller@caltech.edu>
|
|
7
|
+
Requires-Dist: fastapi>=0.115
|
|
8
|
+
Requires-Dist: jsonpatch>=1.33
|
|
9
|
+
Requires-Dist: pydantic>=2.0
|
|
10
|
+
Requires-Dist: websockets>=12.0
|
|
11
|
+
Requires-Dist: pytest ; extra == 'dev'
|
|
12
|
+
Requires-Dist: pytest-asyncio ; extra == 'dev'
|
|
13
|
+
Requires-Dist: httpx ; extra == 'dev'
|
|
14
|
+
Requires-Dist: uvicorn[standard] ; extra == 'dev'
|
|
15
|
+
Requires-Dist: numpy>=1.24 ; extra == 'numpy'
|
|
16
|
+
Requires-Dist: sqlmodel>=0.0.24 ; extra == 'persist'
|
|
17
|
+
Requires-Python: >=3.11
|
|
18
|
+
Project-URL: Homepage, https://sansseriff.github.io/lab-link/
|
|
19
|
+
Project-URL: Repository, https://github.com/sansseriff/lab-link
|
|
20
|
+
Project-URL: Documentation, https://sansseriff.github.io/lab-link/
|
|
21
|
+
Provides-Extra: dev
|
|
22
|
+
Provides-Extra: numpy
|
|
23
|
+
Provides-Extra: persist
|
|
24
|
+
Description-Content-Type: text/markdown
|
|
25
|
+
|
|
26
|
+
# lab-link Python
|
|
27
|
+
|
|
28
|
+
FastAPI/Pydantic backend runtime for `lab-link`.
|
|
29
|
+
|
|
30
|
+
Use it to register authoritative state, expose a WebSocket sync endpoint, run
|
|
31
|
+
commands with hardware side effects, and broadcast versioned JSON Patch updates.
|
|
32
|
+
|
|
33
|
+
```bash
|
|
34
|
+
uv add lab-link
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
```python
|
|
38
|
+
from lab_link import LabSync
|
|
39
|
+
|
|
40
|
+
sync = LabSync()
|
|
41
|
+
sync.register_state(AppState, initial=state)
|
|
42
|
+
app = sync.create_app()
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
Full docs: https://sansseriff.github.io/lab-link/
|
lab_link-0.1.1/README.md
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# lab-link Python
|
|
2
|
+
|
|
3
|
+
FastAPI/Pydantic backend runtime for `lab-link`.
|
|
4
|
+
|
|
5
|
+
Use it to register authoritative state, expose a WebSocket sync endpoint, run
|
|
6
|
+
commands with hardware side effects, and broadcast versioned JSON Patch updates.
|
|
7
|
+
|
|
8
|
+
```bash
|
|
9
|
+
uv add lab-link
|
|
10
|
+
```
|
|
11
|
+
|
|
12
|
+
```python
|
|
13
|
+
from lab_link import LabSync
|
|
14
|
+
|
|
15
|
+
sync = LabSync()
|
|
16
|
+
sync.register_state(AppState, initial=state)
|
|
17
|
+
app = sync.create_app()
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
Full docs: https://sansseriff.github.io/lab-link/
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "lab-link"
|
|
3
|
+
version = "0.1.1"
|
|
4
|
+
description = "Generic server-authoritative sync library for FastAPI"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
authors = [
|
|
7
|
+
{ name = "Andrew Mueller", email = "amueller@caltech.edu" }
|
|
8
|
+
]
|
|
9
|
+
requires-python = ">=3.11"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"fastapi>=0.115",
|
|
12
|
+
"jsonpatch>=1.33",
|
|
13
|
+
"pydantic>=2.0",
|
|
14
|
+
"websockets>=12.0",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[project.urls]
|
|
18
|
+
Homepage = "https://sansseriff.github.io/lab-link/"
|
|
19
|
+
Repository = "https://github.com/sansseriff/lab-link"
|
|
20
|
+
Documentation = "https://sansseriff.github.io/lab-link/"
|
|
21
|
+
|
|
22
|
+
[project.optional-dependencies]
|
|
23
|
+
persist = [
|
|
24
|
+
"sqlmodel>=0.0.24",
|
|
25
|
+
]
|
|
26
|
+
numpy = ["numpy>=1.24"]
|
|
27
|
+
dev = ["pytest", "pytest-asyncio", "httpx", "uvicorn[standard]"]
|
|
28
|
+
|
|
29
|
+
[build-system]
|
|
30
|
+
requires = ["uv_build>=0.9.28,<0.10.0"]
|
|
31
|
+
build-backend = "uv_build"
|
|
32
|
+
|
|
33
|
+
[dependency-groups]
|
|
34
|
+
dev = [
|
|
35
|
+
"httpx>=0.28.1",
|
|
36
|
+
"pytest>=9.0.2",
|
|
37
|
+
"pytest-asyncio>=1.3.0",
|
|
38
|
+
"uvicorn[standard]>=0.42.0",
|
|
39
|
+
]
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import uuid
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from fastapi import WebSocket
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ConnectionManager:
    """Registry of connected WebSocket clients with JSON/binary fan-out.

    The ``connections`` dict is only mutated while holding ``self._lock``.
    Sends happen outside the lock, on a copy of the registry, and any client
    whose send raises is removed from the registry afterwards.
    """

    def __init__(self) -> None:
        # client_id -> accepted WebSocket
        self.connections: dict[str, WebSocket] = {}
        self._lock = asyncio.Lock()

    async def connect(
        self,
        websocket: WebSocket,
        client_id: str,
        state_snapshot: dict[str, Any],
        state_version: int,
        stream_snapshots: list[dict[str, Any]] | None = None,
    ) -> None:
        """Accept *websocket*, register it as *client_id*, and send initial data.

        The client first receives ``{"type": "snapshot", "data": ..., "version": ...}``
        and then every message in *stream_snapshots*, in order.

        Raises:
            Whatever the WebSocket raises while accepting or sending. If an
            initial send fails, the client is unregistered before the error is
            re-raised, so a half-initialised socket never lingers in
            ``connections``.
        """
        await websocket.accept()
        async with self._lock:
            self.connections[client_id] = websocket
        try:
            await websocket.send_json(
                {"type": "snapshot", "data": state_snapshot, "version": state_version}
            )
            for msg in stream_snapshots or []:
                await websocket.send_json(msg)
        except Exception:
            await self.disconnect(client_id)
            raise

    async def disconnect(self, client_id: str) -> None:
        """Forget *client_id*; a no-op if it is not registered."""
        async with self._lock:
            self.connections.pop(client_id, None)

    async def broadcast_patch(
        self,
        patch: list[dict[str, Any]],
        version: int,
        *,
        origin_client_id: str | None = None,
        request_id: str | None = None,
        command: str | None = None,
    ) -> None:
        """Send a ``patch`` message to every client.

        Optional attribution fields are only included when not ``None``
        (as ``originClientId``, ``requestId`` and ``command``).
        """
        message: dict[str, Any] = {"type": "patch", "patch": patch, "version": version}
        if origin_client_id is not None:
            message["originClientId"] = origin_client_id
        if request_id is not None:
            message["requestId"] = request_id
        if command is not None:
            message["command"] = command
        await self._broadcast_json(message)

    async def broadcast_json(self, message: dict[str, Any]) -> None:
        """Send an arbitrary JSON *message* to every client."""
        await self._broadcast_json(message)

    async def broadcast_binary(self, frame: bytes) -> None:
        """Send a binary *frame* to every client, dropping clients whose send fails."""
        failed: list[str] = []
        for client_id, websocket in await self._snapshot_connections():
            try:
                await websocket.send_bytes(frame)
            except Exception:
                failed.append(client_id)
        await self._remove(failed)

    async def send_to(self, client_id: str, message: dict[str, Any]) -> None:
        """Send *message* to one client; drop it if the send fails, ignore unknown ids."""
        async with self._lock:
            ws = self.connections.get(client_id)
        if ws is not None:
            try:
                await ws.send_json(message)
            except Exception:
                await self._remove([client_id])

    async def close_all(self) -> None:
        """Unregister every client and close its socket (best effort, errors ignored)."""
        async with self._lock:
            items = list(self.connections.items())
            self.connections.clear()

        for _, websocket in items:
            try:
                await websocket.close()
            except Exception:
                # Shutting down anyway; a socket that is already gone is fine.
                pass

    async def _broadcast_json(self, message: dict[str, Any]) -> None:
        """Send *message* as JSON to every client, dropping clients whose send fails."""
        failed: list[str] = []
        for client_id, websocket in await self._snapshot_connections():
            try:
                await websocket.send_json(message)
            except Exception:
                failed.append(client_id)
        await self._remove(failed)

    async def _snapshot_connections(self) -> list[tuple[str, WebSocket]]:
        """Return a copy of the registry so sends can run without holding the lock."""
        async with self._lock:
            return list(self.connections.items())

    async def _remove(self, client_ids: list[str]) -> None:
        """Unregister every id in *client_ids* (missing ids are ignored)."""
        if not client_ids:
            return
        async with self._lock:
            for cid in client_ids:
                self.connections.pop(cid, None)

    @staticmethod
    def generate_client_id() -> str:
        """Return a fresh random client id (a UUID4 string)."""
        return str(uuid.uuid4())
|
|
@@ -0,0 +1,547 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import contextvars
|
|
5
|
+
import logging
|
|
6
|
+
from contextlib import asynccontextmanager
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from inspect import Parameter, iscoroutinefunction, signature
|
|
9
|
+
from typing import Any, Callable, Literal, TypeVar
|
|
10
|
+
|
|
11
|
+
from fastapi import APIRouter, FastAPI, WebSocket, WebSocketDisconnect
|
|
12
|
+
from pydantic import BaseModel
|
|
13
|
+
|
|
14
|
+
from .connection_manager import ConnectionManager
|
|
15
|
+
from .errors import CommandError
|
|
16
|
+
from .persistence import PersistenceManager
|
|
17
|
+
from .proxy import StateProxy, SyncState
|
|
18
|
+
from .state_store import StateStore
|
|
19
|
+
from .stream_buffer import AppendBuffer, DeltaBuffer, ReplaceBuffer, StreamRef
|
|
20
|
+
|
|
21
|
+
# Type variable for the pydantic model class accepted by ``LabSync.register_state``.
T = TypeVar("T", bound=BaseModel)
# Type variable for decorated callables, so decorators return the exact type they receive.
_F = TypeVar("_F", bound=Callable[..., Any])
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass(frozen=True, slots=True)
class CommandContext:
    """Immutable description of the command currently being dispatched.

    Built by ``LabSync._dispatch_command`` for each incoming ``command``
    message and passed to handlers that declare a parameter named ``ctx`` or
    annotated as ``CommandContext``.
    """

    client_id: str  # id generated for the WebSocket connection that sent the command
    request_id: str | None  # the message's ``requestId``, if the client sent one
    command: str  # name of the command being run


@dataclass(frozen=True, slots=True)
class PatchMetadata:
    """Attribution attached to a broadcast patch; every field is optional."""

    origin_client_id: str | None = None  # broadcast as ``originClientId`` when set
    request_id: str | None = None  # broadcast as ``requestId`` when set
    command: str | None = None  # broadcast as ``command`` when set


# Context of the command being dispatched in the current task, or None outside
# command handling. ``LabSync._dispatch_command`` sets and resets it around each
# handler call so state writes made inside a handler carry the originating
# client/request/command in their patch metadata.
_current_command_context: contextvars.ContextVar[CommandContext | None] = (
    contextvars.ContextVar("lab_link_command_context", default=None)
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class StateTransaction:
    """Batch several JSON Pointer writes into a single committed patch.

    Intended to be used as a context manager, normally obtained from
    ``LabSync.transaction``. Writes queued with :meth:`set` are passed, together
    with this transaction's :class:`PatchMetadata`, to
    ``LabSync._commit_changes`` when the ``with`` block exits cleanly. If the
    block raises, the queued writes are discarded. In both cases the
    transaction is closed afterwards and any further :meth:`set` raises
    ``RuntimeError``.
    """

    def __init__(self, sync: "LabSync", meta: PatchMetadata) -> None:
        self._sync = sync
        self._meta = meta
        # Pending (path, value) writes in the order they were queued.
        self._changes: list[tuple[str, Any]] = []
        self._closed = False

    def __enter__(self) -> "StateTransaction":
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        self._closed = True
        block_failed = exc_type is not None
        # Commit only when the block succeeded and actually queued something.
        if block_failed or not self._changes:
            return
        self._sync._commit_changes(self._changes, self._meta)

    def set(self, path: str, value: Any) -> None:
        """Queue writing *value* at JSON Pointer *path* for the eventual commit."""
        if not self._closed:
            self._changes.append((path, value))
            return
        raise RuntimeError("transaction is already closed")
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class LabSync:
    """Server-authoritative state synchronisation hub for a FastAPI application.

    Holds one pydantic state model (registered with :meth:`register_state` or
    the ``@sync.state`` decorator), named command handlers, background
    updaters and named streams. These are exposed through :attr:`router`
    (``GET {prefix}/state`` and the ``{prefix}/ws`` WebSocket), and every state
    change is broadcast to connected clients as a versioned JSON Patch message.
    """

    def __init__(
        self,
        prefix: str = "/sync",
        persist: bool = False,
        db_url: str = "sqlite:///lab_link.db",
        compress: bool = False,
    ) -> None:
        """Create an empty sync hub.

        Args:
            prefix: URL prefix for the routes; a trailing ``/`` is stripped.
            persist: if true, :meth:`lifespan` restores state from, and saves
                state to, ``db_url`` through ``PersistenceManager``.
            db_url: database URL handed to ``PersistenceManager``.
            compress: stored as ``self._compress``; not read by this class.
        """
        self._prefix = prefix.rstrip("/")
        self._persist = persist
        self._db_url = db_url
        self._compress = compress

        self._store: StateStore | None = None
        self._commands: dict[str, Callable[..., Any]] = {}
        self._updaters: list[tuple[Callable[..., Any], float]] = []
        self._streams: dict[str, StreamRef] = {}
        self._live_buffers: dict[str, AppendBuffer | ReplaceBuffer | DeltaBuffer] = {}
        self._conn_manager: ConnectionManager | None = None
        self._persistence: PersistenceManager | None = None
        self._patch_queue: asyncio.Queue | None = None
        self._router: APIRouter | None = None
        # Broadcast tasks from _schedule_patch_broadcast that have not finished yet.
        self._pending_patch_tasks: set[asyncio.Task[None]] = set()

        # ``sync.state`` is always a SyncState instance.
        # Before @sync.state is applied it acts as the decorator (callable).
        # After registration it delegates attribute access to the internal StateProxy.
        self.state = SyncState(self._register_state_model)

    # ── model registration ───────────────────────────────────────────────────

    def _register_state_model(
        self,
        cls: type[BaseModel],
        initial: BaseModel | dict[str, Any] | None = None,
    ) -> None:
        """Create the StateStore for *cls* and attach a StateProxy to ``self.state``.

        When *initial* is None, the model's default instance (dumped in JSON
        mode) is used as the initial state.

        Raises:
            TypeError: if *cls* is not a pydantic ``BaseModel`` subclass.
        """
        if not issubclass(cls, BaseModel):
            raise TypeError(f"{cls.__name__} must be a pydantic BaseModel subclass")
        if initial is None:
            initial = cls().model_dump(mode="json")
        self._store = StateStore(cls, initial)
        self._patch_queue = asyncio.Queue()
        self.state._set_proxy(
            StateProxy(self._store, self._patch_queue, self._metadata_from_current_context)
        )

    def register_state(
        self,
        model_class: type[T],
        *,
        initial: T | dict[str, Any] | None = None,
    ) -> None:
        """Register *model_class* as the synced state (non-decorator form)."""
        self._register_state_model(model_class, initial)

    # ── decorators ──────────────────────────────────────────────────────────

    def command(self, fn: _F) -> _F:
        """@sync.command — registers fn under fn.__name__. Supports sync & async."""
        self._commands[fn.__name__] = fn
        return fn

    def updater(self, interval: float = 1.0) -> Callable[[_F], _F]:
        """@sync.updater(interval=0.1) — registers a background polling function.

        Both plain functions and coroutine functions are accepted; the lifespan
        calls them every *interval* seconds.
        """
        def decorator(fn: _F) -> _F:
            self._updaters.append((fn, interval))
            return fn
        return decorator

    # ── stream registration ──────────────────────────────────────────────────

    def stream(
        self,
        id: str,
        *,
        mode: Literal["append", "replace", "int_delta"] = "replace",
        capacity: int = 10_000,
        dtype: Literal["float32", "float64", "json"] = "float32",
    ) -> StreamRef:
        """Register a named stream and return a StreamRef.

        The ref is safe to store at module level and use in updaters — it
        materialises the real buffer automatically when the app lifespan starts.
        """
        ref = StreamRef(id, mode, capacity, dtype)
        self._streams[id] = ref
        if self._conn_manager is not None:
            # Already inside lifespan — materialise immediately
            buf = self._make_buffer(id, mode, capacity, dtype)
            ref.materialize(buf)
            self._live_buffers[id] = buf
        return ref

    def _make_buffer(
        self, id: str, mode: str, capacity: int, dtype: str
    ) -> AppendBuffer | ReplaceBuffer | DeltaBuffer:
        """Build the concrete buffer for a stream *mode*.

        Raises:
            ValueError: for an unknown *mode*.
        """
        cm = self._conn_manager
        if mode == "append":
            return AppendBuffer(id, capacity, cm)
        elif mode == "replace":
            return ReplaceBuffer(id, capacity, dtype, cm)
        elif mode == "int_delta":
            return DeltaBuffer(id, capacity, cm)
        else:
            raise ValueError(f"Unknown stream mode: {mode!r}")

    # ── state access ─────────────────────────────────────────────────────────

    def get(self, path: str) -> Any:
        """Read helper: sync.get('pump/speed') → scalar value."""
        if self._store is None:
            raise RuntimeError("No @sync.state model registered")
        return self._store.get(path)

    def set(self, path: str, value: Any) -> tuple[list[dict[str, Any]], int]:
        """Set a JSON Pointer value, validate state, and broadcast one patch."""
        return self._commit_changes([(path, value)], self._metadata_from_current_context())

    def replace_state(self, state: BaseModel | dict[str, Any]) -> tuple[list[dict[str, Any]], int]:
        """Replace the whole state; broadcast the resulting patch if it is non-empty.

        Returns:
            ``(patch, version)`` as produced by ``StateStore.replace_state``.
        """
        if self._store is None:
            raise RuntimeError("No sync state model registered")
        patch, version = self._store.replace_state(state)
        if patch:
            self._schedule_patch_broadcast(
                patch,
                version,
                self._metadata_from_current_context(),
            )
        return patch, version

    def transaction(
        self,
        *,
        origin: str | None = None,
        request_id: str | None = None,
        command: str | None = None,
    ) -> StateTransaction:
        """Open a StateTransaction that commits all its writes as one patch.

        Any attribution argument left as None falls back to the command
        context currently being dispatched (if any).
        """
        ctx = _current_command_context.get()
        meta = PatchMetadata(
            origin_client_id=origin if origin is not None else (ctx.client_id if ctx else None),
            request_id=request_id if request_id is not None else (ctx.request_id if ctx else None),
            command=command if command is not None else (ctx.command if ctx else None),
        )
        return StateTransaction(self, meta)

    @property
    def streams(self) -> dict[str, AppendBuffer | ReplaceBuffer | DeltaBuffer]:
        """Materialised stream buffers, keyed by stream id."""
        return self._live_buffers

    # ── FastAPI integration ───────────────────────────────────────────────────

    @property
    def router(self) -> APIRouter:
        """Lazily built router with ``GET /state`` and the ``/ws`` WebSocket."""
        if self._router is not None:
            return self._router
        router = APIRouter(prefix=self._prefix)

        @router.get("/state")
        async def get_state() -> dict[str, Any]:
            if self._store is None:
                return {}
            return self._store.snapshot()

        @router.websocket("/ws")
        async def ws_endpoint(websocket: WebSocket) -> None:
            await self._handle_ws(websocket)

        self._router = router
        return router

    @asynccontextmanager
    async def lifespan(self, app: FastAPI | None = None):
        """Use in FastAPI lifespan to start drain task, updaters, persistence.

        On exit (normal or via an exception raised inside the ``async with``)
        background tasks are cancelled and awaited, the state is saved when
        persistence is enabled, and all client sockets are closed.
        """
        self._conn_manager = conn_manager = ConnectionManager()

        # Materialise all StreamRefs now that we have a conn_manager
        for sid, ref in self._streams.items():
            buf = self._make_buffer(sid, ref.mode, ref.capacity, ref.dtype)
            ref.materialize(buf)
            self._live_buffers[sid] = buf

        # Recreate queue in running loop and rebind proxy
        self._patch_queue = asyncio.Queue()
        proxy = self.state._get_proxy()
        if proxy is not None:
            proxy._rebind_queue(self._patch_queue)

        # Optional persistence
        if self._persist and self._store is not None:
            self._persistence = PersistenceManager(self._db_url)
            saved = self._persistence.initialize()
            if saved:
                try:
                    self._store.replace_state(saved)
                except Exception:
                    # Stale or incompatible saved data must not block startup,
                    # but it should not disappear silently either.
                    logger.exception("Could not restore persisted state")

        tasks: list[asyncio.Task[None]] = []

        if self._store is not None:
            tasks.append(
                asyncio.create_task(
                    _drain_patch_queue(
                        self._patch_queue,
                        self._store,
                        conn_manager,
                        self._persistence,
                    )
                )
            )

        for fn, interval in self._updaters:
            tasks.append(asyncio.create_task(_run_updater(fn, interval)))

        try:
            yield
        finally:
            for task in tasks:
                task.cancel()
            for task in tasks:
                try:
                    await task
                except asyncio.CancelledError:
                    pass
                except Exception:
                    # A task that already died must not abort the remaining cleanup.
                    logger.exception("Background task failed before shutdown")

            if self._persistence and self._store:
                self._persistence.save_sync(self._store.snapshot())

            await conn_manager.close_all()

    def create_app(self, **fastapi_kwargs: Any) -> FastAPI:
        """Convenience: creates FastAPI app with lifespan + router pre-wired."""
        @asynccontextmanager
        async def _lifespan(app: FastAPI):
            async with self.lifespan(app):
                yield

        app = FastAPI(lifespan=_lifespan, **fastapi_kwargs)
        app.include_router(self.router)
        return app

    # ── internal WebSocket handler ────────────────────────────────────────────

    async def _handle_ws(self, websocket: WebSocket) -> None:
        """Serve one WebSocket client until it disconnects.

        Sends the state and stream snapshots, then handles ``command`` and
        ``stream_resync`` messages. The client is always unregistered on exit,
        including when the initial snapshot send fails.

        Raises:
            RuntimeError: if called while the lifespan is not running.
        """
        conn_manager = self._conn_manager
        if conn_manager is None:
            raise RuntimeError(
                "LabSync lifespan is not running; use sync.create_app() or sync.lifespan()"
            )
        client_id = ConnectionManager.generate_client_id()
        snapshot = self._store.snapshot() if self._store else {}
        version = self._store.version() if self._store else 0
        stream_snapshots = [
            buf.snapshot_message()
            for buf in self._live_buffers.values()
        ]
        try:
            await conn_manager.connect(
                websocket, client_id, snapshot, version, stream_snapshots
            )
            while True:
                data = await websocket.receive_json()
                if not isinstance(data, dict):
                    # Valid JSON that is not a message object (e.g. a list) is ignored
                    # instead of tearing down the connection.
                    continue
                msg_type = data.get("type")
                if msg_type == "command":
                    await self._dispatch_command(
                        websocket=websocket,
                        client_id=client_id,
                        command=str(data.get("command", "")),
                        params=dict(data.get("params") or {}),
                        request_id=data.get("requestId"),
                    )
                elif msg_type == "stream_resync":
                    stream_id = data.get("id")
                    buf = self._live_buffers.get(stream_id)
                    if buf:
                        await conn_manager.send_to(client_id, buf.snapshot_message())
        except WebSocketDisconnect:
            pass
        except Exception:
            logger.exception("Unhandled WebSocket error for client %s", client_id)
        finally:
            await conn_manager.disconnect(client_id)

    async def _dispatch_command(
        self,
        websocket: WebSocket,
        client_id: str,
        command: str,
        params: dict[str, Any],
        request_id: str | None,
    ) -> None:
        """Run one client command and reply with ``command_ack`` / ``command_error``.

        Replies are only sent when the client supplied *request_id*. Before the
        ack, the proxy patch queue is joined and scheduled broadcasts are
        awaited, so the ack's ``version`` reflects the handler's own writes.
        """
        handler = self._commands.get(command)
        if handler is None:
            if request_id:
                await websocket.send_json(
                    {
                        "type": "command_error",
                        "command": command,
                        "requestId": request_id,
                        "code": "unknown_command",
                        "message": f"Unknown command: {command!r}",
                        "severity": "error",
                        "display": "toast",
                        "recoverable": False,
                        "originClientId": client_id,
                        "version": self._store.version() if self._store else 0,
                    }
                )
            return

        ctx = CommandContext(client_id=client_id, request_id=request_id, command=command)
        token = _current_command_context.set(ctx)
        try:
            if iscoroutinefunction(handler):
                result = await self._call_handler(handler, ctx, params)
            else:
                result = self._call_handler(handler, ctx, params)

            if self._patch_queue is not None:
                await self._patch_queue.join()
            await self._flush_pending_patch_tasks()

            if request_id:
                version = self._store.version() if self._store else 0
                message: dict[str, Any] = {
                    "type": "command_ack",
                    "command": command,
                    "requestId": request_id,
                    "version": version,
                }
                if result is not None:
                    message["result"] = result
                await websocket.send_json(message)
        except CommandError as exc:
            await self._send_command_error(websocket, exc, command, request_id, client_id)
        except Exception as exc:
            logger.exception("Command %r failed", command)
            if request_id:
                await self._send_command_error(
                    websocket,
                    CommandError(
                        code="command_failed",
                        message=str(exc) or "Command failed.",
                        detail=repr(exc),
                        recoverable=True,
                    ),
                    command,
                    request_id,
                    client_id,
                )
        finally:
            _current_command_context.reset(token)

    def _call_handler(
        self,
        handler: Callable[..., Any],
        ctx: CommandContext,
        params: dict[str, Any],
    ) -> Any:
        """Call *handler* with *params* as keyword arguments, injecting *ctx*.

        The context goes to the first named parameter (not already supplied by
        *params*) that is called ``ctx`` or annotated ``CommandContext``.
        ``*args`` / ``**kwargs`` parameters are skipped, because they cannot
        be filled by keyword name.
        """
        sig = signature(handler)
        call_params = dict(params)
        for name, param in sig.parameters.items():
            if name in call_params:
                continue
            if param.kind in {Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD}:
                continue
            if (
                name == "ctx"
                or param.annotation is CommandContext
                or param.annotation == "CommandContext"
            ):
                call_params[name] = ctx
                break
        return handler(**call_params)

    async def _send_command_error(
        self,
        websocket: WebSocket,
        exc: CommandError,
        command: str,
        request_id: str | None,
        client_id: str,
    ) -> None:
        """Send *exc* as a ``command_error`` message; skipped without *request_id*."""
        if not request_id:
            return
        await websocket.send_json(
            exc.to_message(
                command=command,
                request_id=request_id,
                version=self._store.version() if self._store else 0,
                origin_client_id=client_id,
            )
        )

    def _metadata_from_current_context(self) -> PatchMetadata:
        """PatchMetadata for the command being dispatched, or empty metadata."""
        ctx = _current_command_context.get()
        if ctx is None:
            return PatchMetadata()
        return PatchMetadata(
            origin_client_id=ctx.client_id,
            request_id=ctx.request_id,
            command=ctx.command,
        )

    def _commit_changes(
        self,
        changes: list[tuple[str, Any]],
        meta: PatchMetadata,
    ) -> tuple[list[dict[str, Any]], int]:
        """Apply *changes* to the store and schedule a broadcast if a patch resulted."""
        if self._store is None:
            raise RuntimeError("No sync state model registered")
        patch, version = self._store.apply_values(changes)
        if patch:
            self._schedule_patch_broadcast(patch, version, meta)
        return patch, version

    def _schedule_patch_broadcast(
        self,
        patch: list[dict[str, Any]],
        version: int,
        meta: PatchMetadata,
    ) -> None:
        """Broadcast *patch* in a background task; a no-op before the lifespan starts."""
        if self._conn_manager is None:
            return
        task = asyncio.create_task(self._broadcast_patch(patch, version, meta))
        # Keep a strong reference until done so the task is not garbage-collected.
        self._pending_patch_tasks.add(task)
        task.add_done_callback(self._pending_patch_tasks.discard)

    async def _broadcast_patch(
        self,
        patch: list[dict[str, Any]],
        version: int,
        meta: PatchMetadata,
    ) -> None:
        """Send *patch* to all clients and request a debounced save if persisting."""
        if self._conn_manager is None:
            return
        await self._conn_manager.broadcast_patch(
            patch,
            version,
            origin_client_id=meta.origin_client_id,
            request_id=meta.request_id,
            command=meta.command,
        )
        if self._persistence:
            await self._persistence.save_debounced(self._store.snapshot())

    async def _flush_pending_patch_tasks(self) -> None:
        """Wait until every scheduled broadcast task (including new ones) has finished."""
        while self._pending_patch_tasks:
            tasks = list(self._pending_patch_tasks)
            await asyncio.gather(*tasks)
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
# ── background tasks ──────────────────────────────────────────────────────────
|
|
512
|
+
|
|
513
|
+
async def _drain_patch_queue(
    queue: asyncio.Queue,
    store: StateStore,
    conn_manager: ConnectionManager,
    persistence: PersistenceManager | None,
) -> None:
    """Apply queued state writes one by one and broadcast each resulting patch.

    Queue items are ``(path, value)`` or ``(path, value, PatchMetadata)``
    tuples; two-element items get empty metadata. ``queue.task_done()`` is
    called for every item, so ``queue.join()`` waiters are always released.

    A failing item is logged and skipped instead of ending the loop. If this
    task died, later writes would never be applied and every
    ``queue.join()`` (as used when dispatching commands) would block forever.
    The loop only stops when the task is cancelled.
    """
    while True:
        item = await queue.get()
        try:
            if len(item) == 2:
                path, value = item
                meta = PatchMetadata()
            else:
                path, value, meta = item
            patch, version = store.apply_value(path, value)
            await conn_manager.broadcast_patch(
                patch,
                version,
                origin_client_id=meta.origin_client_id,
                request_id=meta.request_id,
                command=meta.command,
            )
            if persistence:
                await persistence.save_debounced(store.snapshot())
        except Exception:
            logger.exception("Failed to apply queued state update")
        finally:
            queue.task_done()
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
async def _run_updater(fn: Callable[..., Any], interval: float) -> None:
    """Call *fn* every *interval* seconds until the task is cancelled.

    *fn* may be a plain function or a coroutine function. The first call
    happens after the first sleep. An exception raised by *fn* is logged and
    polling continues. Previously a single failure stopped the updater for
    good, and the exception resurfaced at shutdown, where the lifespan only
    expects ``CancelledError``.
    """
    while True:
        await asyncio.sleep(interval)
        try:
            if iscoroutinefunction(fn):
                await fn()
            else:
                fn()
        except Exception:
            logger.exception("Updater %r failed", getattr(fn, "__name__", fn))
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Literal
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
Severity = Literal["info", "warning", "error"]
|
|
8
|
+
DisplayHint = Literal["toast", "banner", "inline"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(slots=True)
|
|
12
|
+
class CommandError(Exception):
|
|
13
|
+
code: str
|
|
14
|
+
message: str
|
|
15
|
+
detail: str | None = None
|
|
16
|
+
severity: Severity = "error"
|
|
17
|
+
display: DisplayHint = "toast"
|
|
18
|
+
path: str | None = None
|
|
19
|
+
recoverable: bool = True
|
|
20
|
+
|
|
21
|
+
def __post_init__(self) -> None:
|
|
22
|
+
super().__init__(self.message)
|
|
23
|
+
|
|
24
|
+
def to_message(
|
|
25
|
+
self,
|
|
26
|
+
*,
|
|
27
|
+
command: str,
|
|
28
|
+
request_id: str | None,
|
|
29
|
+
version: int,
|
|
30
|
+
origin_client_id: str | None = None,
|
|
31
|
+
) -> dict[str, Any]:
|
|
32
|
+
payload: dict[str, Any] = {
|
|
33
|
+
"type": "command_error",
|
|
34
|
+
"command": command,
|
|
35
|
+
"requestId": request_id,
|
|
36
|
+
"code": self.code,
|
|
37
|
+
"message": self.message,
|
|
38
|
+
"severity": self.severity,
|
|
39
|
+
"display": self.display,
|
|
40
|
+
"recoverable": self.recoverable,
|
|
41
|
+
"version": version,
|
|
42
|
+
}
|
|
43
|
+
if self.detail is not None:
|
|
44
|
+
payload["detail"] = self.detail
|
|
45
|
+
if self.path is not None:
|
|
46
|
+
payload["path"] = self.path
|
|
47
|
+
if origin_client_id is not None:
|
|
48
|
+
payload["originClientId"] = origin_client_id
|
|
49
|
+
return payload
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class PersistenceManager:
    """Optional SQLite persistence using sqlmodel. Activated when LabSync(persist=True).

    The whole state is stored as one JSON document in the first row of a
    ``LabSyncState`` table. :meth:`initialize` must be called before saving.
    """

    # The SQLModel table class is defined once per process and shared by every
    # instance. Declaring a ``table=True`` model again (e.g. on a second
    # lifespan) would register a duplicate table in ``SQLModel.metadata``.
    _table_model: Any = None

    def __init__(self, db_url: str = "sqlite:///lab_link.db") -> None:
        self._db_url = db_url
        self._engine = None
        self._LabSyncState: Any = None
        self._pending_save: asyncio.Task | None = None
        self._debounce_seconds = 1.0

    @staticmethod
    def _get_table_model() -> Any:
        """Return the shared ``LabSyncState`` table model, defining it on first use."""
        if PersistenceManager._table_model is None:
            from sqlmodel import Field, SQLModel

            class LabSyncState(SQLModel, table=True):
                id: int | None = Field(default=None, primary_key=True)
                state_json: str

            PersistenceManager._table_model = LabSyncState
        return PersistenceManager._table_model

    def initialize(self) -> dict[str, Any] | None:
        """Create tables, return persisted state dict if exists, else None."""
        from sqlmodel import Session, SQLModel, create_engine, select

        LabSyncState = self._get_table_model()
        self._LabSyncState = LabSyncState
        self._engine = create_engine(self._db_url)
        SQLModel.metadata.create_all(self._engine)

        with Session(self._engine) as session:
            row = session.exec(select(LabSyncState)).first()
            if row:
                return json.loads(row.state_json)
        return None

    def save_sync(self, state: dict[str, Any]) -> None:
        """Write *state* (JSON-serialisable) to the single state row, creating it if needed.

        Raises:
            RuntimeError: if :meth:`initialize` has not been called yet.
        """
        if self._engine is None or self._LabSyncState is None:
            raise RuntimeError("PersistenceManager.initialize() must be called before saving")

        from sqlmodel import Session, select

        LabSyncState = self._LabSyncState
        with Session(self._engine) as session:
            row = session.exec(select(LabSyncState)).first()
            if row:
                row.state_json = json.dumps(state)
            else:
                row = LabSyncState(state_json=json.dumps(state))
            session.add(row)
            session.commit()

    async def save_debounced(self, state: dict[str, Any]) -> None:
        """Coalesce rapid saves; only persist after debounce_seconds of inactivity.

        Each call cancels a save that is still waiting and schedules a new one
        with the latest *state*; it returns without waiting for the write.
        """
        if self._pending_save and not self._pending_save.done():
            self._pending_save.cancel()

        self._pending_save = asyncio.create_task(self._debounced_save(state))

    async def _debounced_save(self, state: dict[str, Any]) -> None:
        """Sleep for the debounce interval, then write *state*."""
        await asyncio.sleep(self._debounce_seconds)
        self.save_sync(state)
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def escape_pointer_part(part: object) -> str:
    """Escape a single JSON Pointer reference token (RFC 6901).

    The value is first converted with ``str()``; then every ``~`` becomes
    ``~0`` and every ``/`` becomes ``~1``.
    """
    text = str(part)
    # translate() substitutes all characters in one pass, so the "~" produced
    # by escaping "/" can never be re-escaped.
    escapes = str.maketrans({"~": "~0", "/": "~1"})
    return text.translate(escapes)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def ptr(*parts: object) -> str:
    """Build a JSON Pointer string from raw (unescaped) reference tokens.

    With no arguments this is the root pointer ``""``. Otherwise each part is
    escaped with ``escape_pointer_part`` and prefixed with ``/``; for example
    ``ptr("a", 0, "b/c")`` gives ``"/a/0/b~1c"``.
    """
    return "".join(f"/{escape_pointer_part(token)}" for token in parts)
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, TypeVar
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from .state_store import StateStore
|
|
8
|
+
|
|
9
|
+
# Type variable for ``SyncState.__call__``: the decorated class is returned
# unchanged, so ``@sync.state`` preserves the class's static type.
T = TypeVar("T")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class NestedProxy:
    """Collects attribute names below a top-level field so nested writes become one enqueue.

    ``proxy.a.b = 1`` ends in ``root._enqueue("/<prefix>/a/b", 1)``. Reading any
    public attribute never returns stored data; it only yields a deeper
    ``NestedProxy``. Underscore-prefixed lookups raise ``AttributeError``.
    """

    __slots__ = ("_root", "_path")

    def __init__(self, root: "StateProxy", path: list[str]) -> None:
        # Use object.__setattr__ because our own __setattr__ would enqueue a write.
        object.__setattr__(self, "_root", root)
        object.__setattr__(self, "_path", path)

    def __getattr__(self, name: str) -> "NestedProxy":
        if name.startswith("_"):
            raise AttributeError(name)
        prefix = object.__getattribute__(self, "_path")
        owner = object.__getattribute__(self, "_root")
        return NestedProxy(owner, [*prefix, name])

    def __setattr__(self, name: str, value: Any) -> None:
        segments = [*object.__getattribute__(self, "_path"), name]
        owner = object.__getattribute__(self, "_root")
        owner._enqueue("/" + "/".join(segments), value)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class StateProxy:
    """
    Internal write-through proxy held by ``SyncState`` (users go through ``sync.state``).

    Assigning an attribute does not touch the store. Instead it puts
    ``(json_path, value)`` on the drain queue, or ``(json_path, value, metadata)``
    when a ``metadata_getter`` was supplied. Reading an attribute consults a fresh
    ``store.snapshot()``: dict-valued fields come back as a ``NestedProxy`` so that
    nested assignments are enqueued too, and missing fields read as ``None``.
    """

    __slots__ = ("_store", "_queue", "_metadata_getter")

    def __init__(
        self,
        store: "StateStore",
        queue: asyncio.Queue,
        metadata_getter: Callable[[], Any] | None = None,
    ) -> None:
        # The slots are written with object.__setattr__ because our own
        # __setattr__ would enqueue them as state writes.
        for slot, initial in (
            ("_store", store),
            ("_queue", queue),
            ("_metadata_getter", metadata_getter),
        ):
            object.__setattr__(self, slot, initial)

    def __getattr__(self, name: str) -> Any:
        if name.startswith("_"):
            raise AttributeError(name)
        store: "StateStore" = object.__getattribute__(self, "_store")
        current = store.snapshot().get(name)
        return NestedProxy(self, [name]) if isinstance(current, dict) else current

    def __setattr__(self, name: str, value: Any) -> None:
        self._enqueue("/" + name, value)

    def _enqueue(self, path: str, value: Any) -> None:
        queue: asyncio.Queue = object.__getattribute__(self, "_queue")
        getter: Callable[[], Any] | None = object.__getattribute__(self, "_metadata_getter")
        item = (path, value) if getter is None else (path, value, getter())
        queue.put_nowait(item)

    def _rebind_queue(self, queue: asyncio.Queue) -> None:
        """Point future writes at ``queue`` instead of the current drain queue."""
        object.__setattr__(self, "_queue", queue)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class SyncState:
    """
    The object exposed as ``sync.state``; it is both a decorator and a state handle.

    - ``@sync.state`` passes the decorated class to the registration callback and
      returns the class unchanged.
    - Once a ``StateProxy`` has been installed with ``_set_proxy``, attribute reads
      and writes (``sync.state.x = 5``) are forwarded to it, and it enqueues
      ``("/x", 5)`` on the drain queue. Before that, both raise ``RuntimeError``.

    Using one concrete type for both roles lets type checkers (Pylance, mypy)
    treat ``@sync.state`` as an ordinary decorator that returns the class.
    """

    __slots__ = ("_proxy", "_on_register")

    def __init__(self, on_register: Callable[[type], None]) -> None:
        # Slots are written via object.__setattr__ to bypass our forwarding __setattr__.
        object.__setattr__(self, "_on_register", on_register)
        object.__setattr__(self, "_proxy", None)

    def __call__(self, cls: type[T]) -> type[T]:
        """Register ``cls`` as the sync state model. Used as ``@sync.state``."""
        register: Callable[[type], None] = object.__getattribute__(self, "_on_register")
        register(cls)
        return cls

    def __getattr__(self, name: str) -> Any:
        if name.startswith("_"):
            raise AttributeError(name)
        target: StateProxy | None = object.__getattribute__(self, "_proxy")
        if target is not None:
            return getattr(target, name)
        raise RuntimeError("No @sync.state model registered yet")

    def __setattr__(self, name: str, value: Any) -> None:
        target: StateProxy | None = object.__getattribute__(self, "_proxy")
        if target is not None:
            setattr(target, name, value)
            return
        raise RuntimeError("No @sync.state model registered yet")

    def _set_proxy(self, proxy: StateProxy) -> None:
        """Install the proxy that attribute access is forwarded to."""
        object.__setattr__(self, "_proxy", proxy)

    def _get_proxy(self) -> StateProxy | None:
        """Return the installed proxy, or ``None`` before one is set."""
        return object.__getattribute__(self, "_proxy")
|
|
File without changes
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import json
|
|
5
|
+
import threading
|
|
6
|
+
from collections.abc import Sequence
|
|
7
|
+
from typing import Any, Generic, TypeVar
|
|
8
|
+
|
|
9
|
+
import jsonpatch
|
|
10
|
+
from pydantic import BaseModel, ValidationError
|
|
11
|
+
|
|
12
|
+
# Pydantic model type held by a ``StateStore`` (see ``model_class: type[T]``).
T = TypeVar("T", bound=BaseModel)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class StateStore(Generic[T]):
    """Lock-guarded holder for the authoritative, model-validated state.

    The state is kept as the JSON-mode dump of ``model_class`` and is only ever
    replaced wholesale, never mutated in place. Each successful write bumps an
    integer version and returns the JSON Patch from the previous state to the
    new one. Values handed out (snapshots, ``get`` results, patches) never share
    objects with the internal state.
    """

    def __init__(self, model_class: type[T], initial: BaseModel | dict[str, Any]) -> None:
        self._model_class = model_class
        self._lock = threading.RLock()
        self._state: dict[str, Any] = {}
        self._version: int = 0
        self._replace_internal(initial)

    def _replace_internal(self, state: BaseModel | dict[str, Any]) -> None:
        # Unlike _validate_state, pydantic's ValidationError propagates unchanged here.
        validated = self._model_class.model_validate(state)
        self._state = validated.model_dump(mode="json")

    def snapshot(self) -> dict[str, Any]:
        """Return a deep copy of the whole current state."""
        with self._lock:
            return json.loads(json.dumps(self._state))

    def version(self) -> int:
        """Return the number of successful writes since construction."""
        with self._lock:
            return self._version

    def get(self, path: str) -> Any:
        """Return a deep copy of the value at JSON Pointer ``path``.

        Raises:
            KeyError: a dict key on the path does not exist.
            TypeError: the path traverses a scalar or uses an invalid list index.
            IndexError: a list index is out of range.
            ValueError: the pointer contains an invalid ``~`` escape.
        """
        with self._lock:
            keys = _parse_pointer(path)
            current: Any = self._state
            for key in keys:
                current = _get_child(current, key, path)
            return copy.deepcopy(current)

    def apply_value(self, json_path: str, value: Any) -> tuple[list[dict[str, Any]], int]:
        """Set one existing location; see :meth:`apply_values`."""
        return self.apply_values([(json_path, value)])

    def apply_values(
        self,
        changes: Sequence[tuple[str, Any]],
    ) -> tuple[list[dict[str, Any]], int]:
        """Apply ``(json_path, value)`` writes atomically and re-validate the result.

        All changes are applied to a copy first. If any path is invalid or the
        result fails validation, the stored state and version are left untouched.

        Returns:
            ``(patch_ops, new_version)``.

        Raises:
            ValueError: ``changes`` is empty, a path is the root pointer, or
                validation fails.
        """
        with self._lock:
            if not changes:
                raise ValueError("at least one change is required")

            # self._state is only ever rebound, never mutated, so the current dict
            # can serve directly as the diff base; only the working copy is needed.
            old_state = self._state
            next_state = json.loads(json.dumps(old_state))

            for json_path, value in changes:
                _set_pointer_value(next_state, json_path, value)

            self._state = self._validate_state(next_state)
            self._version += 1
            return self._diff(old_state), self._version

    def replace_state(self, state: BaseModel | dict[str, Any]) -> tuple[list[dict[str, Any]], int]:
        """Validate and install a complete new state; returns ``(patch_ops, new_version)``."""
        with self._lock:
            # _replace_internal rebinds self._state, so the old dict stays intact.
            old_state = self._state
            self._replace_internal(state)
            self._version += 1
            return self._diff(old_state), self._version

    def _diff(self, old_state: dict[str, Any]) -> list[dict[str, Any]]:
        """Return the JSON Patch operations from ``old_state`` to the current state.

        jsonpatch may place objects from the target document directly into the
        operations' ``value`` fields. The round-trip through JSON keeps callers
        from receiving references into ``self._state``.
        """
        patch = jsonpatch.make_patch(old_state, self._state)
        return json.loads(json.dumps(list(patch)))

    def _validate_current_state(self) -> None:
        # NOTE(review): does not take the lock; callers presumably already hold it.
        self._state = self._validate_state(self._state)

    def _validate_state(self, state: Any) -> dict[str, Any]:
        """Validate ``state`` against the model; return its JSON-mode dump.

        Raises:
            ValueError: wrapping pydantic's ``ValidationError``.
        """
        try:
            validated = self._model_class.model_validate(state)
        except ValidationError as exc:
            raise ValueError(f"state validation failed: {exc}") from exc
        return validated.model_dump(mode="json")
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _parse_pointer(path: str) -> list[str]:
    """Split a JSON Pointer into decoded reference tokens.

    ``""`` (the root) yields ``[]``. A missing leading ``/`` is tolerated, so
    ``"a/b"`` is read as ``"/a/b"``.
    """
    if path == "":
        return []
    normalized = path if path.startswith("/") else f"/{path}"
    _, *raw_tokens = normalized.split("/")
    return [_unescape_pointer_part(token) for token in raw_tokens]
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _unescape_pointer_part(part: str) -> str:
    """Decode one RFC 6901 reference token: ``~0`` -> ``~`` and ``~1`` -> ``/``.

    Raises:
        ValueError: a ``~`` is not followed by ``0`` or ``1``.
    """
    decoded: list[str] = []
    chars = iter(part)
    for ch in chars:
        if ch != "~":
            decoded.append(ch)
            continue
        # Consume the escape's second character from the same iterator.
        follower = next(chars, None)
        if follower == "0":
            decoded.append("~")
        elif follower == "1":
            decoded.append("/")
        else:
            raise ValueError(f"invalid JSON Pointer escape in segment {part!r}")
    return "".join(decoded)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _get_child(current: Any, key: str, full_path: str) -> Any:
    """Step one decoded pointer token into a dict or list.

    Raises:
        KeyError: ``current`` is a dict without ``key``.
        TypeError: ``current`` is not a dict/list, or ``key`` is not a usable list index.
        IndexError: the list index is past the end of ``current``.
    """
    if isinstance(current, list):
        return current[_parse_list_index(key, len(current), full_path)]
    if not isinstance(current, dict):
        raise TypeError(f"path {full_path!r} traverses non-container value")
    if key in current:
        return current[key]
    raise KeyError(f"path {full_path!r} does not exist")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _set_pointer_value(state: Any, path: str, value: Any) -> None:
    """Overwrite an existing location ``path`` inside ``state`` in place.

    Dict targets must already contain the final key (new keys are not created).
    List targets accept an in-range index, or ``-`` to append.

    Raises:
        ValueError: ``path`` is the root pointer.
        KeyError, TypeError, IndexError: the path cannot be resolved.
    """
    tokens = _parse_pointer(path)
    if not tokens:
        raise ValueError("path must not be empty; use replace_state() for root replacement")

    *parents, leaf = tokens
    target = state
    for token in parents:
        target = _get_child(target, token, path)

    if isinstance(target, list):
        if leaf == "-":
            target.append(value)
        else:
            target[_parse_list_index(leaf, len(target), path)] = value
        return

    if not isinstance(target, dict):
        raise TypeError(f"path {path!r} targets non-container value")
    if leaf not in target:
        raise KeyError(f"path {path!r} does not exist")
    target[leaf] = value
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _parse_list_index(segment: str, length: int, full_path: str) -> int:
    """Convert a pointer token into an in-range index for a list of ``length`` items.

    Follows the RFC 6901 array-index grammar: ``0``, or ASCII digits with no
    leading zero.

    Raises:
        TypeError: the token is not a valid array index.
        IndexError: the index is ``>= length``.
    """
    # str.isdecimal() also accepts non-ASCII digits ("\u0661", "\uff13") that int()
    # parses; RFC 6901 only allows ASCII 0-9.
    if not (segment.isascii() and segment.isdigit()):
        raise TypeError(f"path {full_path!r} uses non-numeric list index {segment!r}")
    # Without this, "01" would silently alias "1".
    if len(segment) > 1 and segment.startswith("0"):
        raise TypeError(f"path {full_path!r} uses list index {segment!r} with a leading zero")
    index = int(segment)
    if index >= length:
        raise IndexError(f"path {full_path!r} list index {index} out of range")
    return index
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import struct
|
|
4
|
+
from collections import deque
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from .connection_manager import ConnectionManager
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AppendBuffer:
    """Ring buffer for time-series streams (sensor history, logs).

    Keeps at most ``capacity`` most-recent points and broadcasts every append
    as a ``stream_append`` message. ``seq`` increases by one per broadcast
    message, not per point.
    """

    def __init__(self, id: str, capacity: int, conn_manager: "ConnectionManager") -> None:
        self.id = id
        self._capacity = capacity
        self._conn_manager = conn_manager
        # deque(maxlen=...) drops the oldest points once capacity is reached.
        self._buffer: deque[Any] = deque(maxlen=capacity)
        self._seq: int = 0

    async def append(self, point: Any) -> None:
        """Store one point and broadcast it."""
        self._buffer.append(point)
        self._seq += 1
        await self._conn_manager.broadcast_json(
            {"type": "stream_append", "id": self.id, "data": [point], "seq": self._seq}
        )

    async def extend(self, points: list[Any]) -> None:
        """Store several points and broadcast them as a single message.

        ``points`` is materialized into a list first because it is consumed
        twice: once into the ring buffer and once as the broadcast payload.
        A one-shot iterator would otherwise arrive exhausted in the broadcast.
        """
        batch = list(points)
        self._buffer.extend(batch)
        self._seq += 1
        await self._conn_manager.broadcast_json(
            {"type": "stream_append", "id": self.id, "data": batch, "seq": self._seq}
        )

    def snapshot_message(self) -> dict[str, Any]:
        """Return a ``stream_snapshot`` message holding a copy of the buffered points."""
        return {
            "type": "stream_snapshot",
            "id": self.id,
            "data": list(self._buffer),
            "seq": self._seq,
        }
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ReplaceBuffer:
    """Full-replace buffer for float arrays (FFT, KDE, waveform).

    Each ``replace`` swaps the whole payload and broadcasts it: as a binary
    frame for ``float32``/``float64`` dtypes, otherwise as a
    ``stream_replace`` JSON message.
    """

    def __init__(
        self,
        id: str,
        capacity: int,
        dtype: str,
        conn_manager: "ConnectionManager",
    ) -> None:
        self.id = id
        # NOTE(review): stored but not enforced anywhere in this class.
        self._capacity = capacity
        self._dtype = dtype
        self._conn_manager = conn_manager
        self._data: list[float] = []
        self._seq: int = 0

    async def replace(self, data: "list[float] | Any") -> None:
        """Replace the stored data and broadcast it.

        numpy arrays are converted with ``tolist()`` when numpy is importable;
        any other iterable is copied with ``list()``.
        """
        try:
            import numpy as np
        except ImportError:
            np = None
        if np is not None and isinstance(data, np.ndarray):
            data = data.tolist()
        self._data = list(data)
        self._seq += 1

        if self._dtype in ("float32", "float64"):
            frame = self._encode_binary(self._data)
            await self._conn_manager.broadcast_binary(frame)
        else:
            await self._conn_manager.broadcast_json(
                {"type": "stream_replace", "id": self.id, "data": self._data, "seq": self._seq}
            )

    def _encode_binary(self, data: list[float]) -> bytes:
        """Encode ``data`` as a binary frame.

        Layout (little-endian): ``u8`` frame type ``0x01``, ``u16`` byte length of
        the id, the UTF-8 id, ``u32`` seq, then one ``f32`` (``float32``) or
        ``f64`` (otherwise) per value.
        """
        id_bytes = self.id.encode("utf-8")
        header = struct.pack("<BH", 0x01, len(id_bytes)) + id_bytes + struct.pack("<I", self._seq)
        code = "f" if self._dtype == "float32" else "d"
        # A repeat count ("<1024d") keeps the format string short, however long the data is.
        payload = struct.pack(f"<{len(data)}{code}", *data)
        return header + payload

    def snapshot_message(self) -> dict[str, Any]:
        """Return a ``stream_snapshot`` message holding a copy of the current data."""
        return {
            "type": "stream_snapshot",
            "id": self.id,
            # Copy, as the other buffers do, so callers cannot mutate internal state.
            "data": list(self._data),
            "seq": self._seq,
        }
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class DeltaBuffer:
    """Sparse-delta buffer for integer-count histograms.

    Clients receive only the non-zero ``[index, delta]`` pairs of each update.
    ``replace`` broadcasts the full bins as a ``stream_snapshot``.
    """

    def __init__(self, id: str, num_bins: int, conn_manager: "ConnectionManager") -> None:
        self.id = id
        self._num_bins = num_bins
        self._conn_manager = conn_manager
        self._bins: list[int] = [0] * num_bins
        self._seq: int = 0

    async def apply_delta(self, deltas: dict[int, int]) -> None:
        """Add ``deltas`` (bin index -> increment) and broadcast the non-zero ones.

        Every index is checked before any bin changes. An invalid batch
        therefore leaves the histogram untouched, rather than half-applied and
        never broadcast.

        Raises:
            IndexError: if any index is outside ``0 <= idx < len(bins)``.
        """
        n_bins = len(self._bins)
        for idx in deltas:
            # Negative indices must be rejected explicitly: list indexing would
            # wrap them to the end, while clients would receive the raw index.
            if not 0 <= idx < n_bins:
                raise IndexError(f"delta index {idx} out of range for {n_bins} bins")
        for idx, delta in deltas.items():
            self._bins[idx] += delta
        self._seq += 1
        sparse = [[idx, delta] for idx, delta in deltas.items() if delta != 0]
        await self._conn_manager.broadcast_json(
            {"type": "stream_delta", "id": self.id, "deltas": sparse, "seq": self._seq}
        )

    async def replace(self, bins: list[int]) -> None:
        """Replace all bins with a copy of ``bins`` and broadcast a snapshot."""
        self._bins = list(bins)
        self._seq += 1
        await self._conn_manager.broadcast_json(self.snapshot_message())

    def snapshot_message(self) -> dict[str, Any]:
        """Return a ``stream_snapshot`` message holding a copy of the bins."""
        return {
            "type": "stream_snapshot",
            "id": self.id,
            "data": list(self._bins),
            "seq": self._seq,
        }
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class StreamRef:
    """
    Lazy reference returned by ``sync.stream()``.

    Holds the stream spec (id, mode, capacity, dtype) until the app lifespan starts
    and materialises the real buffer. After that all method calls are forwarded.

    This lets user code store the return value of ``sync.stream()`` at module level
    and call ``await buf.append(...)`` inside updaters without worrying about
    whether the lifespan has started yet.

    Calling a method that the materialised buffer type does not support raises
    ``TypeError``.
    """

    def __init__(
        self,
        id: str,
        mode: Literal["append", "replace", "int_delta"],
        capacity: int,
        dtype: Literal["float32", "float64", "json"],
    ) -> None:
        self.id = id
        self.mode = mode
        self.capacity = capacity
        self.dtype = dtype
        self._buffer: AppendBuffer | ReplaceBuffer | DeltaBuffer | None = None

    def materialize(self, buffer: AppendBuffer | ReplaceBuffer | DeltaBuffer) -> None:
        """Attach the concrete buffer that later calls are forwarded to."""
        self._buffer = buffer

    def _require(self) -> AppendBuffer | ReplaceBuffer | DeltaBuffer:
        """Return the materialised buffer, or raise ``RuntimeError`` if there is none yet."""
        if self._buffer is None:
            raise RuntimeError(
                f"Stream {self.id!r} is not yet active — "
                "ensure it is accessed after the app lifespan has started."
            )
        return self._buffer

    def _unsupported(self, method: str) -> TypeError:
        """Build the error for calling ``method`` on a buffer of the wrong kind.

        These are explicit checks rather than ``assert``, because asserts vanish
        under ``python -O``. Without them a wrong-mode call would become an
        AttributeError or, for ``replace()`` on a delta stream, a silent call to
        ``DeltaBuffer.replace``.
        """
        return TypeError(f"Stream {self.id!r} (mode={self.mode!r}) does not support {method}()")

    # ── AppendBuffer interface ────────────────────────────────────────────────

    async def append(self, point: Any) -> None:
        buf = self._require()
        if not isinstance(buf, AppendBuffer):
            raise self._unsupported("append")
        await buf.append(point)

    async def extend(self, points: list[Any]) -> None:
        buf = self._require()
        if not isinstance(buf, AppendBuffer):
            raise self._unsupported("extend")
        await buf.extend(points)

    # ── ReplaceBuffer interface ───────────────────────────────────────────────

    async def replace(self, data: Any) -> None:
        buf = self._require()
        if not isinstance(buf, ReplaceBuffer):
            raise self._unsupported("replace")
        await buf.replace(data)

    # ── DeltaBuffer interface ─────────────────────────────────────────────────

    async def apply_delta(self, deltas: dict[int, int]) -> None:
        buf = self._require()
        if not isinstance(buf, DeltaBuffer):
            raise self._unsupported("apply_delta")
        await buf.apply_delta(deltas)

    def snapshot_message(self) -> dict[str, Any]:
        """Forward to the materialised buffer's ``snapshot_message()``."""
        return self._require().snapshot_message()
|