pycrdt: 0.12.12-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl → 0.12.45-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pycrdt has been flagged as possibly problematic.
- pycrdt/__init__.py +11 -0
- pycrdt/_array.py +3 -3
- pycrdt/_awareness.py +25 -1
- pycrdt/_base.py +44 -7
- pycrdt/_doc.py +116 -11
- pycrdt/_map.py +6 -4
- pycrdt/_provider.py +176 -0
- pycrdt/_pycrdt.cpython-311-powerpc64le-linux-gnu.so +0 -0
- pycrdt/_pycrdt.pyi +96 -6
- pycrdt/_snapshot.py +56 -0
- pycrdt/_sticky_index.py +145 -0
- pycrdt/_text.py +2 -2
- pycrdt/_transaction.py +46 -15
- pycrdt/_undo.py +11 -1
- pycrdt/_version.py +1 -7
- pycrdt/_xml.py +8 -3
- pycrdt-0.12.45.dist-info/METADATA +35 -0
- pycrdt-0.12.45.dist-info/RECORD +23 -0
- pycrdt-0.12.45.dist-info/WHEEL +5 -0
- pycrdt-0.12.12.dist-info/METADATA +0 -47
- pycrdt-0.12.12.dist-info/RECORD +0 -20
- pycrdt-0.12.12.dist-info/WHEEL +0 -4
- {pycrdt-0.12.12.dist-info → pycrdt-0.12.45.dist-info}/licenses/LICENSE +0 -0
pycrdt/__init__.py
CHANGED
@@ -1,16 +1,27 @@
+from pkgutil import extend_path
+
+__path__ = extend_path(__path__, __name__)
+
 from ._array import Array as Array
 from ._array import ArrayEvent as ArrayEvent
 from ._array import TypedArray as TypedArray
 from ._awareness import Awareness as Awareness
+from ._awareness import is_awareness_disconnect_message as is_awareness_disconnect_message
 from ._doc import Doc as Doc
 from ._doc import TypedDoc as TypedDoc
 from ._map import Map as Map
 from ._map import MapEvent as MapEvent
 from ._map import TypedMap as TypedMap
+from ._provider import Channel as Channel
+from ._provider import Provider as Provider
+from ._pycrdt import DeleteSet as DeleteSet
 from ._pycrdt import StackItem as StackItem
 from ._pycrdt import SubdocsEvent as SubdocsEvent
 from ._pycrdt import Subscription as Subscription
 from ._pycrdt import TransactionEvent as TransactionEvent
+from ._snapshot import Snapshot as Snapshot
+from ._sticky_index import Assoc as Assoc
+from ._sticky_index import StickyIndex as StickyIndex
 from ._sync import Decoder as Decoder
 from ._sync import Encoder as Encoder
 from ._sync import YMessageType as YMessageType
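The new modules are re-exported at the package root. A minimal sketch (assuming pycrdt 0.12.45 is installed) of the names this release adds to the top-level namespace, taken directly from the imports above:

```py
# Names newly exported from the pycrdt package root in 0.12.45.
from pycrdt import (
    Assoc,                             # association side for sticky indices
    Channel,                           # protocol for provider transports
    DeleteSet,
    Provider,
    Snapshot,
    StickyIndex,
    is_awareness_disconnect_message,
)
```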
pycrdt/_array.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, cast, overload
 
-from ._base import BaseDoc, BaseEvent, BaseType, Typed, base_types, event_types
+from ._base import BaseDoc, BaseEvent, BaseType, Sequence, Typed, base_types, event_types
 from ._pycrdt import Array as _Array
 from ._pycrdt import ArrayEvent as _ArrayEvent
 from ._pycrdt import Subscription
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
 T = TypeVar("T")
 
 
-class Array(BaseType, Generic[T]):
+class Array(Sequence, Generic[T]):
     """
     A collection used to store data in an indexed sequence structure, similar to a Python `list`.
     """
@@ -399,7 +399,7 @@ class ArrayIterator:
         self.idx = 0
 
     def __iter__(self) -> ArrayIterator:
-        return self
+        return self  # pragma: nocover
 
     def __next__(self) -> Any:
         if self.idx == self.length:
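Since `Array` now derives from the new `Sequence` base class (defined in `_base.py` below), it gains the `sticky_index()` helper. A minimal, hedged sketch of how it might be used; the `StickyIndex` object's own API lives in the new `_sticky_index.py`, which this diff does not reproduce:

```py
from pycrdt import Array, Assoc, Doc

doc = Doc()
doc["items"] = items = Array(["a", "b", "c"])

# Record a position that keeps pointing at the same logical place even when
# concurrent inserts or deletes shift the numeric index.
sticky = items.sticky_index(1, Assoc.AFTER)
```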
pycrdt/_awareness.py
CHANGED
@@ -10,7 +10,7 @@ from anyio import TASK_STATUS_IGNORED, create_task_group, sleep
 from anyio.abc import TaskGroup, TaskStatus
 
 from ._doc import Doc
-from ._sync import Decoder, Encoder
+from ._sync import Decoder, Encoder, read_message
 
 
 class Awareness:
@@ -278,3 +278,27 @@ class Awareness:
             id: The subscription ID to unregister.
         """
         del self._subscriptions[id]
+
+
+def is_awareness_disconnect_message(message: bytes) -> bool:
+    """
+    Check if the message is null, which means that it is a disconnection message
+    from the client.
+
+    Args:
+        message: The message received from the client.
+
+    Returns:
+        Whether the message is a disconnection message or not.
+    """
+    decoder = Decoder(read_message(message))
+    length = decoder.read_var_uint()
+    # A disconnection message should be a single message
+    if length == 1:
+        # Remove client_id and clock information from message (not used)
+        for _ in range(2):
+            decoder.read_var_uint()
+        state = decoder.read_var_string()
+        if state == "null":
+            return True
+    return False
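A small, hedged sketch of how the new helper might be used on the receiving side of an awareness channel; the `handle_awareness_message()` wrapper is hypothetical, only `is_awareness_disconnect_message()` comes from the diff above:

```py
from pycrdt import is_awareness_disconnect_message

def handle_awareness_message(message: bytes) -> None:
    # `message` is assumed to be a raw awareness payload received from a client.
    if is_awareness_disconnect_message(message):
        print("client sent a disconnection (null state) message")
    else:
        ...  # forward the payload to the local Awareness instance
```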
pycrdt/_base.py
CHANGED
@@ -4,6 +4,7 @@ import threading
 from abc import ABC, abstractmethod
 from functools import lru_cache, partial
 from inspect import signature
+from types import UnionType
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -12,17 +13,21 @@ from typing import (
     Type,
     Union,
     cast,
+    get_args,
+    get_origin,
     get_type_hints,
     overload,
 )
 
 import anyio
 from anyio import BrokenResourceError, create_memory_object_stream
+from anyio.abc import TaskGroup
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 
 from ._pycrdt import Doc as _Doc
 from ._pycrdt import Subscription
 from ._pycrdt import Transaction as _Transaction
+from ._sticky_index import Assoc, StickyIndex
 from ._transaction import ReadTransaction, Transaction
 
 if TYPE_CHECKING:
@@ -42,17 +47,20 @@ class BaseDoc:
     _doc: _Doc
     _twin_doc: BaseDoc | None
     _txn: Transaction | None
+    _exceptions: list[Exception]
     _txn_lock: threading.Lock
     _txn_async_lock: anyio.Lock
     _allow_multithreading: bool
     _Model: Any
     _subscriptions: list[Subscription]
     _origins: dict[int, Any]
+    _task_group: TaskGroup | None
 
     def __init__(
         self,
         *,
         client_id: int | None = None,
+        skip_gc: bool | None = None,
         doc: _Doc | None = None,
         Model=None,
         allow_multithreading: bool = False,
@@ -60,15 +68,17 @@
     ) -> None:
         super().__init__(**data)
         if doc is None:
-            doc = _Doc(client_id)
+            doc = _Doc(client_id, skip_gc)
         self._doc = doc
         self._txn = None
+        self._exceptions = []
         self._txn_lock = threading.Lock()
         self._txn_async_lock = anyio.Lock()
         self._Model = Model
         self._subscriptions = []
         self._origins = {}
         self._allow_multithreading = allow_multithreading
+        self._task_group = None
 
 
 class BaseType(ABC):
@@ -261,11 +271,30 @@
             except BrokenResourceError:
                 to_remove.append(send_stream)
         for send_stream in to_remove:
+            send_stream.close()
             send_streams.remove(send_stream)
         if not send_streams:
             self.unobserve(self._event_subscription[deep])
 
 
+class Sequence(BaseType):
+    def sticky_index(self, index: int, assoc: Assoc = Assoc.AFTER) -> StickyIndex:
+        """
+        A permanent position that sticks to the same place even when
+        concurrent updates are made.
+
+        Args:
+            index: The index at which to stick.
+            assoc: The [Assoc][pycrdt.Assoc] specifying whether to stick to the location
+                before or after the index.
+
+        Returns:
+            A [StickyIndex][pycrdt.StickyIndex] that can be used to retrieve the index after
+                an update was applied.
+        """
+        return StickyIndex.new(self, index, assoc)
+
+
 def observe_callback(
     callback: Callable[[], None] | Callable[[Any], None] | Callable[[Any, ReadTransaction], None],
     doc: Doc,
@@ -275,7 +304,10 @@ def observe_callback(
     _event = event_types[type(event)](event, doc)
     with doc._read_transaction(event.transaction) as txn:
         params = (_event, txn)
-        callback(*params[:param_nb])  # type: ignore[arg-type]
+        try:
+            callback(*params[:param_nb])  # type: ignore[arg-type]
+        except Exception as exc:
+            doc._exceptions.append(exc)
 
 
 def observe_deep_callback(
@@ -288,7 +320,10 @@ def observe_deep_callback(
         events[idx] = event_types[type(event)](event, doc)
     with doc._read_transaction(event.transaction) as txn:
         params = (events, txn)
-        callback(*params[:param_nb])  # type: ignore[arg-type]
+        try:
+            callback(*params[:param_nb])  # type: ignore[arg-type]
+        except Exception as exc:
+            doc._exceptions.append(exc)
 
 
 class BaseEvent:
@@ -364,10 +399,12 @@ class Typed:
         if key not in annotations:
             raise AttributeError(f'"{type(self).mro()[0]}" has no attribute "{key}"')
         expected_type = annotations[key]
-
-
-
-
+        origin = get_origin(expected_type)
+        if origin in (Union, UnionType):
+            expected_types = get_args(expected_type)
+        elif origin is not None:
+            expected_type = origin
+            expected_types = (expected_type,)
         else:
             expected_types = (expected_type,)
         if type(value) not in expected_types:
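One user-visible effect of the `_base.py` changes is that exceptions raised inside observer callbacks are caught and appended to the document's internal `_exceptions` list instead of escaping directly from the observation machinery. A hedged sketch; how and when the stored exceptions are surfaced depends on the `_transaction.py` changes, which this diff does not show in full:

```py
from pycrdt import Doc, Map

doc = Doc()
doc["data"] = data = Map()

def failing_callback(event) -> None:
    raise ValueError("boom")

data.observe(failing_callback)

try:
    # Triggers the callback; per the wrapped observe_callback above, the
    # ValueError is appended to doc._exceptions.
    data["key"] = 1
except ValueError:
    # Depending on how the transaction machinery surfaces doc._exceptions,
    # the error may be re-raised here when the implicit transaction ends.
    pass
```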
pycrdt/_doc.py
CHANGED
@@ -1,7 +1,20 @@
 from __future__ import annotations
 
 from functools import partial
-from
+from inspect import iscoroutinefunction
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Generic,
+    Iterable,
+    Literal,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
 
 from anyio import BrokenResourceError, create_memory_object_stream
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -10,9 +23,13 @@ from ._base import BaseDoc, BaseType, Typed, base_types, forbid_read_transaction
 from ._pycrdt import Doc as _Doc
 from ._pycrdt import SubdocsEvent, Subscription, TransactionEvent
 from ._pycrdt import Transaction as _Transaction
+from ._snapshot import Snapshot
 from ._transaction import NewTransaction, ReadTransaction, Transaction
 
 T = TypeVar("T", bound=BaseType)
+TransactionOrSubdocsEvent = TypeVar(
+    "TransactionOrSubdocsEvent", bound=TransactionEvent | SubdocsEvent
+)
 
 
 class Doc(BaseDoc, Generic[T]):
@@ -30,6 +47,7 @@ class Doc(BaseDoc, Generic[T]):
         init: dict[str, T] = {},
         *,
         client_id: int | None = None,
+        skip_gc: bool | None = None,
         doc: _Doc | None = None,
         Model=None,
         allow_multithreading: bool = False,
@@ -38,10 +56,16 @@
         Args:
             init: The initial root types of the document.
             client_id: An optional client ID for the document.
+            skip_gc: Whether to skip garbage collection on deleted collections
+                on transaction commit.
             allow_multithreading: Whether to allow the document to be used in different threads.
         """
         super().__init__(
-            client_id=client_id,
+            client_id=client_id,
+            skip_gc=skip_gc,
+            doc=doc,
+            Model=Model,
+            allow_multithreading=allow_multithreading,
         )
         for k, v in init.items():
             self[k] = v
@@ -140,7 +164,9 @@
         Returns:
            The current document state.
        """
-        return self._doc.get_state()
+        with self.transaction() as txn:
+            assert txn._txn is not None
+            return self._doc.get_state(txn._txn)
 
     def get_update(self, state: bytes | None = None) -> bytes:
         """
@@ -152,7 +178,9 @@
         """
         if state is None:
             state = b"\x00"
-        return self._doc.get_update(state)
+        with self.transaction() as txn:
+            assert txn._txn is not None
+            return self._doc.get_update(txn._txn, state)
 
     def apply_update(self, update: bytes) -> None:
         """
@@ -173,6 +201,19 @@
             assert txn._txn is not None
             self._doc.apply_update(txn._txn, update)
 
+    @staticmethod
+    def from_snapshot(snapshot: "Snapshot", doc: "Doc") -> "Doc":
+        """
+        Create a new Doc from a Snapshot and an original Doc.
+        Args:
+            snapshot: The Snapshot to restore to.
+            doc: The original Doc to use for options/state.
+        Returns:
+            A new Doc instance restored to the snapshot state.
+        """
+        new_doc = _Doc.from_snapshot(snapshot._snapshot, doc._doc)
+        return Doc(doc=new_doc)
+
     def __setitem__(self, key: str, value: T) -> None:
         """
         Sets a document root type:
@@ -267,21 +308,44 @@
             for key, val in self._doc.roots(txn._txn).items()
         }
 
-    def observe(self, callback: Callable[[TransactionEvent], None]) -> Subscription:
+    def observe(
+        self,
+        callback: Callable[[TransactionEvent], None]
+        | Callable[[TransactionEvent], Awaitable[None]],
+    ) -> Subscription:
         """
         Subscribes a callback to be called with the document change event.
 
         Args:
             callback: The callback to call with the [TransactionEvent][pycrdt.TransactionEvent].
+                If the callback is async, async transactions must be used.
 
         Returns:
             The subscription that can be used to [unobserve()][pycrdt.Doc.unobserve].
         """
-        subscription = self._doc.observe(partial(observe_callback, callback, self))
+        if iscoroutinefunction(callback):
+            cb = self._async_callback_to_sync(callback)
+        else:
+            cb = partial(observe_callback, cast(Callable[[TransactionEvent], None], callback), self)
+        subscription = self._doc.observe(cb)
         self._subscriptions.append(subscription)
         return subscription
 
-    def observe_subdocs(self, callback: Callable[[SubdocsEvent], None]) -> Subscription:
+    def _async_callback_to_sync(
+        self,
+        async_callback: Callable[[TransactionOrSubdocsEvent], Awaitable[None]],
+    ) -> Callable[[TransactionOrSubdocsEvent], None]:
+        def callback(event: TransactionOrSubdocsEvent) -> None:
+            if self._task_group is None:
+                raise RuntimeError("Async callback in non-async transaction")
+            self._task_group.start_soon(async_callback, event)
+
+        return callback
+
+    def observe_subdocs(
+        self,
+        callback: Callable[[SubdocsEvent], None] | Callable[[SubdocsEvent], Awaitable[None]],
+    ) -> Subscription:
         """
         Subscribes a callback to be called with the document subdoc change event.
 
@@ -291,7 +355,11 @@
         Returns:
             The subscription that can be used to [unobserve()][pycrdt.Doc.unobserve].
         """
-        subscription = self._doc.observe_subdocs(partial(observe_callback, callback, self))
+        if iscoroutinefunction(callback):
+            cb = self._async_callback_to_sync(callback)
+        else:
+            cb = partial(observe_callback, cast(Callable[[SubdocsEvent], None], callback), self)
+        subscription = self._doc.observe_subdocs(cb)
         self._subscriptions.append(subscription)
         return subscription
 
@@ -308,21 +376,24 @@
     @overload
     def events(
         self,
-        subdocs: Literal[False],
+        subdocs: Literal[False] = False,
         max_buffer_size: float = float("inf"),
+        async_transactions: bool = False,
    ) -> MemoryObjectReceiveStream[TransactionEvent]: ...
 
     @overload
     def events(
         self,
-        subdocs: Literal[True],
+        subdocs: Literal[True] = True,
         max_buffer_size: float = float("inf"),
+        async_transactions: bool = False,
     ) -> MemoryObjectReceiveStream[list[SubdocsEvent]]: ...
 
     def events(
         self,
         subdocs: bool = False,
         max_buffer_size: float = float("inf"),
+        async_transactions: bool = False,
     ):
         """
         Allows to asynchronously iterate over the document events, without using a callback.
@@ -343,13 +414,21 @@
             subdocs: Whether to iterate over the [SubdocsEvent][pycrdt.SubdocsEvent] events
                 (default is [TransactionEvent][pycrdt.TransactionEvent]).
             max_buffer_size: Maximum number of events that can be buffered.
+            async_transactions: Whether async transactions are used for this document,
+                in which case iterating over the events can put back-pressure on the
+                transactions (don't use an infinite `max_buffer_size` in this case).
 
         Returns:
             An async iterator over the document events.
         """
         observe = self.observe_subdocs if subdocs else self.observe
         if not self._send_streams[subdocs]:
-            self._event_subscription[subdocs] = observe(partial(self._send_event, subdocs))
+            if async_transactions:
+                self._event_subscription[subdocs] = observe(
+                    partial(self._async_send_event, subdocs)
+                )
+            else:
+                self._event_subscription[subdocs] = observe(partial(self._send_event, subdocs))
         send_stream, receive_stream = create_memory_object_stream[
             Union[TransactionEvent, SubdocsEvent]
         ](max_buffer_size=max_buffer_size)
@@ -365,6 +444,21 @@
             except BrokenResourceError:
                 to_remove.append(send_stream)
         for send_stream in to_remove:
+            send_stream.close()
+            send_streams.remove(send_stream)
+        if not send_streams:
+            self.unobserve(self._event_subscription[subdocs])
+
+    async def _async_send_event(self, subdocs: bool, event: TransactionEvent | SubdocsEvent):
+        to_remove: list[MemoryObjectSendStream[TransactionEvent | SubdocsEvent]] = []
+        send_streams = self._send_streams[subdocs]
+        for send_stream in send_streams:
+            try:
+                await send_stream.send(event)
+            except BrokenResourceError:
+                to_remove.append(send_stream)
+        for send_stream in to_remove:
+            send_stream.close()
             send_streams.remove(send_stream)
         if not send_streams:
             self.unobserve(self._event_subscription[subdocs])
@@ -409,4 +503,15 @@ class TypedDoc(Typed):
             doc[name] = root_type
 
 
+def observe_callback(
+    callback: Callable[[TransactionEvent], None] | Callable[[SubdocsEvent], None],
+    doc: Doc,
+    event: Any,
+) -> None:
+    try:
+        callback(event)
+    except Exception as exc:
+        doc._exceptions.append(exc)
+
+
 base_types[_Doc] = Doc
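A hedged sketch of the new async-callback support in `Doc.observe()`: a coroutine function is scheduled on the document's task group, which per the docstring requires async transactions (entered here with `async with doc.new_transaction()`, pycrdt's async transaction API):

```py
import anyio
from pycrdt import Doc, Text, TransactionEvent

async def main() -> None:
    doc = Doc()
    doc["text"] = text = Text()

    async def on_change(event: TransactionEvent) -> None:
        # Scheduled on the document's task group when a transaction commits.
        print(f"received an update of {len(event.update)} bytes")

    doc.observe(on_change)

    async with doc.new_transaction():
        text += "Hello, world!"

anyio.run(main)
```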
pycrdt/_map.py
CHANGED
@@ -203,7 +203,9 @@ class Map(BaseType, Generic[T]):
         Returns:
             True if the key was found.
         """
-
+        # use integrated.has to avoid fetching all keys
+        with self.doc.transaction() as txn:
+            return self.integrated.has(txn._txn, item)
 
     @overload
     def get(self, key: str) -> T | None: ...
@@ -224,7 +226,7 @@
         """
         key, *default_value = args
         with self.doc.transaction():
-            if
+            if self.__contains__(key):
                 return self[key]
             if not default_value:
                 return None
@@ -248,7 +250,7 @@
         """
         key, *default_value = args
         with self.doc.transaction():
-            if
+            if not self.__contains__(key):
                 if not default_value:
                     raise KeyError
                 return default_value[0]
@@ -261,7 +263,7 @@
     def _check_key(self, key: str) -> None:
         if not isinstance(key, str):
             raise RuntimeError("Key must be of type string")
-        if
+        if not self.__contains__(key):
             raise KeyError(key)
 
     def keys(self) -> Iterable[str]:
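A minimal sketch of the `Map` operations touched by this change; `in`, `get()` and `pop()` all go through `__contains__`, which now asks the backing store directly (`integrated.has`) instead of materializing every key:

```py
from pycrdt import Doc, Map

doc = Doc()
doc["config"] = config = Map({"theme": "dark"})

assert "theme" in config                # membership test without fetching all keys
assert config.get("missing", 42) == 42  # default returned when the key is absent
assert config.pop("theme") == "dark"
```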
pycrdt/_provider.py
ADDED
@@ -0,0 +1,176 @@
+from __future__ import annotations
+
+from contextlib import AsyncExitStack, asynccontextmanager
+from logging import Logger, getLogger
+from typing import AsyncIterator, Protocol
+
+from anyio import (
+    TASK_STATUS_IGNORED,
+    Event,
+    Lock,
+    create_task_group,
+)
+from anyio.abc import TaskGroup, TaskStatus
+
+from ._doc import Doc
+from ._sync import (
+    YMessageType,
+    YSyncMessageType,
+    create_sync_message,
+    create_update_message,
+    handle_sync_message,
+)
+
+
+class Channel(Protocol):
+    """A transport-agnostic stream used to synchronize a document through a provider.
+    An example of a channel is a WebSocket.
+
+    Messages can be received through the channel using an async iterator,
+    until the connection is closed:
+    ```py
+    async for message in channel:
+        ...
+    ```
+    Or directly by calling `recv()`:
+    ```py
+    message = await channel.recv()
+    ```
+    Sending messages is done with `send()`:
+    ```py
+    await channel.send(message)
+    ```
+    """
+
+    @property
+    def path(self) -> str:
+        """The channel path."""
+        ...  # pragma: nocover
+
+    def __aiter__(self) -> "Channel":
+        return self
+
+    async def __anext__(self) -> bytes:
+        return await self.recv()
+
+    async def send(self, message: bytes) -> None:
+        """Send a message.
+
+        Args:
+            message: The message to send.
+        """
+        ...  # pragma: nocover
+
+    async def recv(self) -> bytes:
+        """Receive a message.
+
+        Returns:
+            The received message.
+        """
+        ...  # pragma: nocover
+
+
+class Provider:
+    def __init__(self, doc: Doc, channel: Channel, log: Logger | None = None) -> None:
+        """A provider synchronizes a document through a channel.
+
+        The provider should preferably be used with an async context manager:
+        ```py
+        async with provider:
+            ...
+        ```
+        However, a lower-level API can also be used:
+        ```py
+        task = asyncio.create_task(provider.start())
+        await provider.started.wait()
+        ...
+        await provider.stop()
+        ```
+
+        Arguments:
+            doc: The `Doc` to connect through the `Channel`.
+            channel: The `Channel` through which to connect the `Doc`.
+            log: An optional logger.
+        """
+        self._doc = doc
+        self._channel = channel
+        self.log = log or getLogger(__name__)
+        self.started = Event()
+        self._start_lock = Lock()
+        self._task_group: TaskGroup | None = None
+
+    async def _run(self):
+        sync_message = create_sync_message(self._doc)
+        self.log.debug(
+            "Sending %s message to endpoint: %s",
+            YSyncMessageType.SYNC_STEP1.name,
+            self._channel.path,
+        )
+        await self._channel.send(sync_message)
+        assert self._task_group is not None
+        self._task_group.start_soon(self._send_updates)
+        async for message in self._channel:
+            if message[0] == YMessageType.SYNC:
+                self.log.debug(
+                    "Received %s message from endpoint: %s",
+                    YSyncMessageType(message[1]).name,
+                    self._channel.path,
+                )
+                reply = handle_sync_message(message[1:], self._doc)
+                if reply is not None:
+                    self.log.debug(
+                        "Sending %s message to endpoint: %s",
+                        YSyncMessageType.SYNC_STEP2.name,
+                        self._channel.path,
+                    )
+                    await self._channel.send(reply)
+
+    async def _send_updates(self):
+        async with self._doc.events() as events:
+            async for event in events:
+                message = create_update_message(event.update)
+                await self._channel.send(message)
+
+    async def __aenter__(self) -> Provider:
+        async with AsyncExitStack() as exit_stack:
+            self._task_group = await exit_stack.enter_async_context(
+                self._get_or_create_task_group()
+            )
+            await self._task_group.start(self.start)
+            self._exit_stack = exit_stack.pop_all()
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, exc_tb):
+        await self.stop()
+        return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)
+
+    @asynccontextmanager
+    async def _get_or_create_task_group(self) -> AsyncIterator[TaskGroup]:
+        if self._task_group is not None:
+            yield self._task_group
+            return
+
+        async with create_task_group() as tg:
+            yield tg
+
+    async def start(
+        self,
+        *,
+        task_status: TaskStatus[None] = TASK_STATUS_IGNORED,
+    ) -> None:
+        """Start the provider.
+
+        Args:
+            task_status: The status to set when the task has started.
+        """
+        async with self._start_lock:
+            async with self._get_or_create_task_group() as self._task_group:
+                task_status.started()
+                self.started.set()
+                self._task_group.start_soon(self._run)
+
+    async def stop(self) -> None:
+        """Stop the provider."""
+        assert self._task_group is not None
+        self._task_group.cancel_scope.cancel()
+        self._task_group = None
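To see the new `Provider` in action, here is a hedged, self-contained sketch: a toy in-memory `Channel` (the protocol above only requires `path`, `send()`, `recv()` and async iteration) wiring two documents together through two providers. A real channel would typically wrap a WebSocket connection; the `MemoryChannel` class and the stream plumbing are illustrative assumptions, not part of pycrdt:

```py
import anyio
from pycrdt import Doc, Provider, Text

class MemoryChannel:
    """Toy Channel implementation backed by a pair of anyio memory streams."""

    def __init__(self, send_stream, receive_stream, path: str = "/demo"):
        self._send_stream = send_stream
        self._receive_stream = receive_stream
        self.path = path

    def __aiter__(self):
        return self

    async def __anext__(self) -> bytes:
        return await self.recv()

    async def send(self, message: bytes) -> None:
        await self._send_stream.send(message)

    async def recv(self) -> bytes:
        return await self._receive_stream.receive()

async def main() -> None:
    a_to_b = anyio.create_memory_object_stream[bytes](max_buffer_size=100)
    b_to_a = anyio.create_memory_object_stream[bytes](max_buffer_size=100)

    doc_a, doc_b = Doc(), Doc()
    doc_a["text"] = Text("hello")
    doc_b["text"] = Text()

    # Each provider sends on one stream and receives on the other.
    async with Provider(doc_a, MemoryChannel(a_to_b[0], b_to_a[1])):
        async with Provider(doc_b, MemoryChannel(b_to_a[0], a_to_b[1])):
            await anyio.sleep(0.1)  # let SYNC_STEP1/STEP2 and updates flow
            print(str(doc_b["text"]))  # expected: "hello"

anyio.run(main)
```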
pycrdt/_pycrdt.cpython-311-powerpc64le-linux-gnu.so
CHANGED
Binary file (no diff shown)