pycrdt 0.12.12__cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl → 0.12.45__cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pycrdt might be problematic. Click here for more details.

pycrdt/__init__.py CHANGED
@@ -1,16 +1,27 @@
1
+ from pkgutil import extend_path
2
+
3
+ __path__ = extend_path(__path__, __name__)
4
+
1
5
  from ._array import Array as Array
2
6
  from ._array import ArrayEvent as ArrayEvent
3
7
  from ._array import TypedArray as TypedArray
4
8
  from ._awareness import Awareness as Awareness
9
+ from ._awareness import is_awareness_disconnect_message as is_awareness_disconnect_message
5
10
  from ._doc import Doc as Doc
6
11
  from ._doc import TypedDoc as TypedDoc
7
12
  from ._map import Map as Map
8
13
  from ._map import MapEvent as MapEvent
9
14
  from ._map import TypedMap as TypedMap
15
+ from ._provider import Channel as Channel
16
+ from ._provider import Provider as Provider
17
+ from ._pycrdt import DeleteSet as DeleteSet
10
18
  from ._pycrdt import StackItem as StackItem
11
19
  from ._pycrdt import SubdocsEvent as SubdocsEvent
12
20
  from ._pycrdt import Subscription as Subscription
13
21
  from ._pycrdt import TransactionEvent as TransactionEvent
22
+ from ._snapshot import Snapshot as Snapshot
23
+ from ._sticky_index import Assoc as Assoc
24
+ from ._sticky_index import StickyIndex as StickyIndex
14
25
  from ._sync import Decoder as Decoder
15
26
  from ._sync import Encoder as Encoder
16
27
  from ._sync import YMessageType as YMessageType
pycrdt/_array.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, cast, overload
4
4
 
5
- from ._base import BaseDoc, BaseEvent, BaseType, Typed, base_types, event_types
5
+ from ._base import BaseDoc, BaseEvent, BaseType, Sequence, Typed, base_types, event_types
6
6
  from ._pycrdt import Array as _Array
7
7
  from ._pycrdt import ArrayEvent as _ArrayEvent
8
8
  from ._pycrdt import Subscription
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
13
13
  T = TypeVar("T")
14
14
 
15
15
 
16
- class Array(BaseType, Generic[T]):
16
+ class Array(Sequence, Generic[T]):
17
17
  """
18
18
  A collection used to store data in an indexed sequence structure, similar to a Python `list`.
19
19
  """
@@ -399,7 +399,7 @@ class ArrayIterator:
399
399
  self.idx = 0
400
400
 
401
401
  def __iter__(self) -> ArrayIterator:
402
- return self
402
+ return self # pragma: nocover
403
403
 
404
404
  def __next__(self) -> Any:
405
405
  if self.idx == self.length:
pycrdt/_awareness.py CHANGED
@@ -10,7 +10,7 @@ from anyio import TASK_STATUS_IGNORED, create_task_group, sleep
10
10
  from anyio.abc import TaskGroup, TaskStatus
11
11
 
12
12
  from ._doc import Doc
13
- from ._sync import Decoder, Encoder
13
+ from ._sync import Decoder, Encoder, read_message
14
14
 
15
15
 
16
16
  class Awareness:
@@ -278,3 +278,27 @@ class Awareness:
278
278
  id: The subscription ID to unregister.
279
279
  """
280
280
  del self._subscriptions[id]
281
+
282
+
283
def is_awareness_disconnect_message(message: bytes) -> bool:
    """
    Tell whether an awareness message is a disconnection message from the client.

    The payload extracted with `read_message` is decoded, and the message is
    considered a disconnection when it holds exactly one entry whose state is
    the string `"null"`.

    Args:
        message: The message received from the client.

    Returns:
        `True` if the message is a disconnection message, `False` otherwise.
    """
    decoder = Decoder(read_message(message))
    # A disconnection message must contain a single entry.
    if decoder.read_var_uint() != 1:
        return False
    # Consume the client_id and clock fields: they precede the state but are not used.
    decoder.read_var_uint()
    decoder.read_var_uint()
    return decoder.read_var_string() == "null"
pycrdt/_base.py CHANGED
@@ -4,6 +4,7 @@ import threading
4
4
  from abc import ABC, abstractmethod
5
5
  from functools import lru_cache, partial
6
6
  from inspect import signature
7
+ from types import UnionType
7
8
  from typing import (
8
9
  TYPE_CHECKING,
9
10
  Any,
@@ -12,17 +13,21 @@ from typing import (
12
13
  Type,
13
14
  Union,
14
15
  cast,
16
+ get_args,
17
+ get_origin,
15
18
  get_type_hints,
16
19
  overload,
17
20
  )
18
21
 
19
22
  import anyio
20
23
  from anyio import BrokenResourceError, create_memory_object_stream
24
+ from anyio.abc import TaskGroup
21
25
  from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
22
26
 
23
27
  from ._pycrdt import Doc as _Doc
24
28
  from ._pycrdt import Subscription
25
29
  from ._pycrdt import Transaction as _Transaction
30
+ from ._sticky_index import Assoc, StickyIndex
26
31
  from ._transaction import ReadTransaction, Transaction
27
32
 
28
33
  if TYPE_CHECKING:
@@ -42,17 +47,20 @@ class BaseDoc:
42
47
  _doc: _Doc
43
48
  _twin_doc: BaseDoc | None
44
49
  _txn: Transaction | None
50
+ _exceptions: list[Exception]
45
51
  _txn_lock: threading.Lock
46
52
  _txn_async_lock: anyio.Lock
47
53
  _allow_multithreading: bool
48
54
  _Model: Any
49
55
  _subscriptions: list[Subscription]
50
56
  _origins: dict[int, Any]
57
+ _task_group: TaskGroup | None
51
58
 
52
59
  def __init__(
53
60
  self,
54
61
  *,
55
62
  client_id: int | None = None,
63
+ skip_gc: bool | None = None,
56
64
  doc: _Doc | None = None,
57
65
  Model=None,
58
66
  allow_multithreading: bool = False,
@@ -60,15 +68,17 @@ class BaseDoc:
60
68
  ) -> None:
61
69
  super().__init__(**data)
62
70
  if doc is None:
63
- doc = _Doc(client_id)
71
+ doc = _Doc(client_id, skip_gc)
64
72
  self._doc = doc
65
73
  self._txn = None
74
+ self._exceptions = []
66
75
  self._txn_lock = threading.Lock()
67
76
  self._txn_async_lock = anyio.Lock()
68
77
  self._Model = Model
69
78
  self._subscriptions = []
70
79
  self._origins = {}
71
80
  self._allow_multithreading = allow_multithreading
81
+ self._task_group = None
72
82
 
73
83
 
74
84
  class BaseType(ABC):
@@ -261,11 +271,30 @@ class BaseType(ABC):
261
271
  except BrokenResourceError:
262
272
  to_remove.append(send_stream)
263
273
  for send_stream in to_remove:
274
+ send_stream.close()
264
275
  send_streams.remove(send_stream)
265
276
  if not send_streams:
266
277
  self.unobserve(self._event_subscription[deep])
267
278
 
268
279
 
280
class Sequence(BaseType):
    """Base class for shared types that can hand out sticky indexes (e.g. `Array`)."""

    def sticky_index(self, index: int, assoc: Assoc = Assoc.AFTER) -> StickyIndex:
        """
        Create a permanent position that keeps pointing at the same place even when
        concurrent updates are made.

        Args:
            index: The index at which to stick.
            assoc: The [Assoc][pycrdt.Assoc] telling whether the position sticks to the
                location before or after the index.

        Returns:
            A [StickyIndex][pycrdt.StickyIndex] from which the index can be retrieved
            after an update was applied.
        """
        sticky = StickyIndex.new(self, index, assoc)
        return sticky
296
+
297
+
269
298
  def observe_callback(
270
299
  callback: Callable[[], None] | Callable[[Any], None] | Callable[[Any, ReadTransaction], None],
271
300
  doc: Doc,
@@ -275,7 +304,10 @@ def observe_callback(
275
304
  _event = event_types[type(event)](event, doc)
276
305
  with doc._read_transaction(event.transaction) as txn:
277
306
  params = (_event, txn)
278
- callback(*params[:param_nb]) # type: ignore[arg-type]
307
+ try:
308
+ callback(*params[:param_nb]) # type: ignore[arg-type]
309
+ except Exception as exc:
310
+ doc._exceptions.append(exc)
279
311
 
280
312
 
281
313
  def observe_deep_callback(
@@ -288,7 +320,10 @@ def observe_deep_callback(
288
320
  events[idx] = event_types[type(event)](event, doc)
289
321
  with doc._read_transaction(event.transaction) as txn:
290
322
  params = (events, txn)
291
- callback(*params[:param_nb]) # type: ignore[arg-type]
323
+ try:
324
+ callback(*params[:param_nb]) # type: ignore[arg-type]
325
+ except Exception as exc:
326
+ doc._exceptions.append(exc)
292
327
 
293
328
 
294
329
  class BaseEvent:
@@ -364,10 +399,12 @@ class Typed:
364
399
  if key not in annotations:
365
400
  raise AttributeError(f'"{type(self).mro()[0]}" has no attribute "{key}"')
366
401
  expected_type = annotations[key]
367
- if hasattr(expected_type, "__origin__"):
368
- expected_type = expected_type.__origin__
369
- if hasattr(expected_type, "__args__"):
370
- expected_types = expected_type.__args__
402
+ origin = get_origin(expected_type)
403
+ if origin in (Union, UnionType):
404
+ expected_types = get_args(expected_type)
405
+ elif origin is not None:
406
+ expected_type = origin
407
+ expected_types = (expected_type,)
371
408
  else:
372
409
  expected_types = (expected_type,)
373
410
  if type(value) not in expected_types:
pycrdt/_doc.py CHANGED
@@ -1,7 +1,20 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from functools import partial
4
- from typing import Any, Callable, Generic, Iterable, Literal, Type, TypeVar, Union, cast, overload
4
+ from inspect import iscoroutinefunction
5
+ from typing import (
6
+ Any,
7
+ Awaitable,
8
+ Callable,
9
+ Generic,
10
+ Iterable,
11
+ Literal,
12
+ Type,
13
+ TypeVar,
14
+ Union,
15
+ cast,
16
+ overload,
17
+ )
5
18
 
6
19
  from anyio import BrokenResourceError, create_memory_object_stream
7
20
  from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -10,9 +23,13 @@ from ._base import BaseDoc, BaseType, Typed, base_types, forbid_read_transaction
10
23
  from ._pycrdt import Doc as _Doc
11
24
  from ._pycrdt import SubdocsEvent, Subscription, TransactionEvent
12
25
  from ._pycrdt import Transaction as _Transaction
26
+ from ._snapshot import Snapshot
13
27
  from ._transaction import NewTransaction, ReadTransaction, Transaction
14
28
 
15
29
  T = TypeVar("T", bound=BaseType)
30
+ TransactionOrSubdocsEvent = TypeVar(
31
+ "TransactionOrSubdocsEvent", bound=TransactionEvent | SubdocsEvent
32
+ )
16
33
 
17
34
 
18
35
  class Doc(BaseDoc, Generic[T]):
@@ -30,6 +47,7 @@ class Doc(BaseDoc, Generic[T]):
30
47
  init: dict[str, T] = {},
31
48
  *,
32
49
  client_id: int | None = None,
50
+ skip_gc: bool | None = None,
33
51
  doc: _Doc | None = None,
34
52
  Model=None,
35
53
  allow_multithreading: bool = False,
@@ -38,10 +56,16 @@ class Doc(BaseDoc, Generic[T]):
38
56
  Args:
39
57
  init: The initial root types of the document.
40
58
  client_id: An optional client ID for the document.
59
+ skip_gc: Whether to skip garbage collection on deleted collections
60
+ on transaction commit.
41
61
  allow_multithreading: Whether to allow the document to be used in different threads.
42
62
  """
43
63
  super().__init__(
44
- client_id=client_id, doc=doc, Model=Model, allow_multithreading=allow_multithreading
64
+ client_id=client_id,
65
+ skip_gc=skip_gc,
66
+ doc=doc,
67
+ Model=Model,
68
+ allow_multithreading=allow_multithreading,
45
69
  )
46
70
  for k, v in init.items():
47
71
  self[k] = v
@@ -140,7 +164,9 @@ class Doc(BaseDoc, Generic[T]):
140
164
  Returns:
141
165
  The current document state.
142
166
  """
143
- return self._doc.get_state()
167
+ with self.transaction() as txn:
168
+ assert txn._txn is not None
169
+ return self._doc.get_state(txn._txn)
144
170
 
145
171
  def get_update(self, state: bytes | None = None) -> bytes:
146
172
  """
@@ -152,7 +178,9 @@ class Doc(BaseDoc, Generic[T]):
152
178
  """
153
179
  if state is None:
154
180
  state = b"\x00"
155
- return self._doc.get_update(state)
181
+ with self.transaction() as txn:
182
+ assert txn._txn is not None
183
+ return self._doc.get_update(txn._txn, state)
156
184
 
157
185
  def apply_update(self, update: bytes) -> None:
158
186
  """
@@ -173,6 +201,19 @@ class Doc(BaseDoc, Generic[T]):
173
201
  assert txn._txn is not None
174
202
  self._doc.apply_update(txn._txn, update)
175
203
 
204
@staticmethod
def from_snapshot(snapshot: "Snapshot", doc: "Doc") -> "Doc":
    """
    Build a new Doc out of a Snapshot and the Doc it was taken from.

    Args:
        snapshot: The Snapshot to restore to.
        doc: The original Doc to use for options/state.

    Returns:
        A new Doc instance restored to the snapshot state.
    """
    # The restoration itself is delegated to the native document type.
    restored = _Doc.from_snapshot(snapshot._snapshot, doc._doc)
    return Doc(doc=restored)
216
+
176
217
  def __setitem__(self, key: str, value: T) -> None:
177
218
  """
178
219
  Sets a document root type:
@@ -267,21 +308,44 @@ class Doc(BaseDoc, Generic[T]):
267
308
  for key, val in self._doc.roots(txn._txn).items()
268
309
  }
269
310
 
270
- def observe(self, callback: Callable[[TransactionEvent], None]) -> Subscription:
311
+ def observe(
312
+ self,
313
+ callback: Callable[[TransactionEvent], None]
314
+ | Callable[[TransactionEvent], Awaitable[None]],
315
+ ) -> Subscription:
271
316
  """
272
317
  Subscribes a callback to be called with the document change event.
273
318
 
274
319
  Args:
275
320
  callback: The callback to call with the [TransactionEvent][pycrdt.TransactionEvent].
321
+ If the callback is async, async transactions must be used.
276
322
 
277
323
  Returns:
278
324
  The subscription that can be used to [unobserve()][pycrdt.Doc.unobserve].
279
325
  """
280
- subscription = self._doc.observe(callback)
326
+ if iscoroutinefunction(callback):
327
+ cb = self._async_callback_to_sync(callback)
328
+ else:
329
+ cb = partial(observe_callback, cast(Callable[[TransactionEvent], None], callback), self)
330
+ subscription = self._doc.observe(cb)
281
331
  self._subscriptions.append(subscription)
282
332
  return subscription
283
333
 
284
- def observe_subdocs(self, callback: Callable[[SubdocsEvent], None]) -> Subscription:
334
def _async_callback_to_sync(
    self,
    async_callback: Callable[[TransactionOrSubdocsEvent], Awaitable[None]],
) -> Callable[[TransactionOrSubdocsEvent], None]:
    """
    Adapt an async observer so that it can be registered as a synchronous callback.

    Args:
        async_callback: The coroutine function to run for each event.

    Returns:
        A synchronous function that, when called with an event, schedules
        `async_callback(event)` on the document's current task group. It raises
        a `RuntimeError` if the document has no task group at that moment.
    """

    def schedule(event: TransactionOrSubdocsEvent) -> None:
        task_group = self._task_group
        if task_group is None:
            # No task group is attached to the document: nothing can run the coroutine.
            raise RuntimeError("Async callback in non-async transaction")
        task_group.start_soon(async_callback, event)

    return schedule
344
+
345
+ def observe_subdocs(
346
+ self,
347
+ callback: Callable[[SubdocsEvent], None] | Callable[[SubdocsEvent], Awaitable[None]],
348
+ ) -> Subscription:
285
349
  """
286
350
  Subscribes a callback to be called with the document subdoc change event.
287
351
 
@@ -291,7 +355,11 @@ class Doc(BaseDoc, Generic[T]):
291
355
  Returns:
292
356
  The subscription that can be used to [unobserve()][pycrdt.Doc.unobserve].
293
357
  """
294
- subscription = self._doc.observe_subdocs(callback)
358
+ if iscoroutinefunction(callback):
359
+ cb = self._async_callback_to_sync(callback)
360
+ else:
361
+ cb = partial(observe_callback, cast(Callable[[SubdocsEvent], None], callback), self)
362
+ subscription = self._doc.observe_subdocs(cb)
295
363
  self._subscriptions.append(subscription)
296
364
  return subscription
297
365
 
@@ -308,21 +376,24 @@ class Doc(BaseDoc, Generic[T]):
308
376
  @overload
309
377
  def events(
310
378
  self,
311
- subdocs: Literal[False],
379
+ subdocs: Literal[False] = False,
312
380
  max_buffer_size: float = float("inf"),
381
+ async_transactions: bool = False,
313
382
  ) -> MemoryObjectReceiveStream[TransactionEvent]: ...
314
383
 
315
384
  @overload
316
385
  def events(
317
386
  self,
318
- subdocs: Literal[True],
387
+ subdocs: Literal[True] = True,
319
388
  max_buffer_size: float = float("inf"),
389
+ async_transactions: bool = False,
320
390
  ) -> MemoryObjectReceiveStream[list[SubdocsEvent]]: ...
321
391
 
322
392
  def events(
323
393
  self,
324
394
  subdocs: bool = False,
325
395
  max_buffer_size: float = float("inf"),
396
+ async_transactions: bool = False,
326
397
  ):
327
398
  """
328
399
  Allows to asynchronously iterate over the document events, without using a callback.
@@ -343,13 +414,21 @@ class Doc(BaseDoc, Generic[T]):
343
414
  subdocs: Whether to iterate over the [SubdocsEvent][pycrdt.SubdocsEvent] events
344
415
  (default is [TransactionEvent][pycrdt.TransactionEvent]).
345
416
  max_buffer_size: Maximum number of events that can be buffered.
417
+ async_transactions: Whether async transactions are used for this document,
418
+ in which case iterating over the events can put back-pressure on the
419
+ transactions (don't use an infinite `max_buffer_size` in this case).
346
420
 
347
421
  Returns:
348
422
  An async iterator over the document events.
349
423
  """
350
424
  observe = self.observe_subdocs if subdocs else self.observe
351
425
  if not self._send_streams[subdocs]:
352
- self._event_subscription[subdocs] = observe(partial(self._send_event, subdocs))
426
+ if async_transactions:
427
+ self._event_subscription[subdocs] = observe(
428
+ partial(self._async_send_event, subdocs)
429
+ )
430
+ else:
431
+ self._event_subscription[subdocs] = observe(partial(self._send_event, subdocs))
353
432
  send_stream, receive_stream = create_memory_object_stream[
354
433
  Union[TransactionEvent, SubdocsEvent]
355
434
  ](max_buffer_size=max_buffer_size)
@@ -365,6 +444,21 @@ class Doc(BaseDoc, Generic[T]):
365
444
  except BrokenResourceError:
366
445
  to_remove.append(send_stream)
367
446
  for send_stream in to_remove:
447
+ send_stream.close()
448
+ send_streams.remove(send_stream)
449
+ if not send_streams:
450
+ self.unobserve(self._event_subscription[subdocs])
451
+
452
async def _async_send_event(self, subdocs: bool, event: TransactionEvent | SubdocsEvent):
    """
    Forward an event to every stream registered for this kind of event, awaiting each send.

    Streams whose send fails with `BrokenResourceError` are closed and unregistered.
    When no stream is left, the document observer that feeds them is removed.

    Args:
        subdocs: Selects the group of streams (and the subscription) to use.
        event: The event to forward.
    """
    send_streams = self._send_streams[subdocs]
    broken: list[MemoryObjectSendStream[TransactionEvent | SubdocsEvent]] = []
    for stream in send_streams:
        try:
            await stream.send(event)
        except BrokenResourceError:
            # Only remember the stream here: the list is being iterated.
            broken.append(stream)
    for stream in broken:
        stream.close()
        send_streams.remove(stream)
    # NOTE(review): the diff rendering lost indentation; this check is assumed to sit
    # after the removal loop (function level) — confirm against the real file.
    if send_streams:
        return
    self.unobserve(self._event_subscription[subdocs])
@@ -409,4 +503,15 @@ class TypedDoc(Typed):
409
503
  doc[name] = root_type
410
504
 
411
505
 
506
def observe_callback(
    callback: Callable[[TransactionEvent], None] | Callable[[SubdocsEvent], None],
    doc: Doc,
    event: Any,
) -> None:
    """
    Call a document observer with an event, recording any failure instead of raising it.

    Args:
        callback: The observer to call.
        doc: The document whose `_exceptions` list receives an exception raised
            by the callback.
        event: The event passed to the callback.
    """
    try:
        callback(event)
    except Exception as error:
        # The exception is stored on the document rather than propagated to the caller.
        recorded = doc._exceptions
        recorded.append(error)
515
+
516
+
412
517
  base_types[_Doc] = Doc
pycrdt/_map.py CHANGED
@@ -203,7 +203,9 @@ class Map(BaseType, Generic[T]):
203
203
  Returns:
204
204
  True if the key was found.
205
205
  """
206
- return item in self.keys()
206
+ # use integrated.has to avoid fetching all keys
207
+ with self.doc.transaction() as txn:
208
+ return self.integrated.has(txn._txn, item)
207
209
 
208
210
  @overload
209
211
  def get(self, key: str) -> T | None: ...
@@ -224,7 +226,7 @@ class Map(BaseType, Generic[T]):
224
226
  """
225
227
  key, *default_value = args
226
228
  with self.doc.transaction():
227
- if key in self.keys():
229
+ if self.__contains__(key):
228
230
  return self[key]
229
231
  if not default_value:
230
232
  return None
@@ -248,7 +250,7 @@ class Map(BaseType, Generic[T]):
248
250
  """
249
251
  key, *default_value = args
250
252
  with self.doc.transaction():
251
- if key not in self.keys():
253
+ if not self.__contains__(key):
252
254
  if not default_value:
253
255
  raise KeyError
254
256
  return default_value[0]
@@ -261,7 +263,7 @@ class Map(BaseType, Generic[T]):
261
263
  def _check_key(self, key: str) -> None:
262
264
  if not isinstance(key, str):
263
265
  raise RuntimeError("Key must be of type string")
264
- if key not in self.keys():
266
+ if not self.__contains__(key):
265
267
  raise KeyError(key)
266
268
 
267
269
  def keys(self) -> Iterable[str]:
pycrdt/_provider.py ADDED
@@ -0,0 +1,176 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import AsyncExitStack, asynccontextmanager
4
+ from logging import Logger, getLogger
5
+ from typing import AsyncIterator, Protocol
6
+
7
+ from anyio import (
8
+ TASK_STATUS_IGNORED,
9
+ Event,
10
+ Lock,
11
+ create_task_group,
12
+ )
13
+ from anyio.abc import TaskGroup, TaskStatus
14
+
15
+ from ._doc import Doc
16
+ from ._sync import (
17
+ YMessageType,
18
+ YSyncMessageType,
19
+ create_sync_message,
20
+ create_update_message,
21
+ handle_sync_message,
22
+ )
23
+
24
+
25
class Channel(Protocol):
    """A transport-agnostic stream through which a provider synchronizes a document.
    A WebSocket is an example of a channel.

    Messages are sent with `send()`:
    ```py
    await channel.send(message)
    ```
    They can be received one at a time with `recv()`:
    ```py
    message = await channel.recv()
    ```
    Or by iterating over the channel until the connection is closed:
    ```py
    async for message in channel:
        ...
    ```
    """

    @property
    def path(self) -> str:
        """The channel path."""
        ...  # pragma: nocover

    async def send(self, message: bytes) -> None:
        """Send a message.

        Args:
            message: The message to send.
        """
        ...  # pragma: nocover

    async def recv(self) -> bytes:
        """Receive a message.

        Returns:
            The received message.
        """
        ...  # pragma: nocover

    def __aiter__(self) -> "Channel":
        # The channel is its own async iterator.
        return self

    async def __anext__(self) -> bytes:
        # Each iteration step is one `recv()`.
        received = await self.recv()
        return received
71
+
72
+
73
class Provider:
    """Synchronizes a `Doc` through a `Channel`."""

    def __init__(self, doc: Doc, channel: Channel, log: Logger | None = None) -> None:
        """A provider synchronizes a document through a channel.

        The provider should preferably be used with an async context manager:
        ```py
        async with provider:
            ...
        ```
        However, a lower-level API can also be used:
        ```py
        task = asyncio.create_task(provider.start())
        await provider.started.wait()
        ...
        await provider.stop()
        ```

        Args:
            doc: The `Doc` to connect through the `Channel`.
            channel: The `Channel` through which to connect the `Doc`.
            log: An optional logger.
        """
        self._doc = doc
        self._channel = channel
        self.log = log or getLogger(__name__)
        self.started = Event()
        self._start_lock = Lock()
        self._task_group: TaskGroup | None = None

    @staticmethod
    def _sync_type(message: bytes) -> YSyncMessageType | None:
        """Return the sync sub-type found in the second byte of `message`.

        Returns:
            The `YSyncMessageType`, or `None` when the byte is missing or is not a
            known sub-type.
        """
        try:
            return YSyncMessageType(message[1])
        except (IndexError, ValueError):
            return None

    async def _run(self):
        """Send the initial sync message, start forwarding local updates, then reply
        to the SYNC messages received from the channel until it is exhausted.

        Empty messages and SYNC messages whose sub-type is missing or unknown are
        logged and ignored, so that one malformed frame from the peer does not
        terminate the provider.
        """
        sync_message = create_sync_message(self._doc)
        self.log.debug(
            "Sending %s message to endpoint: %s",
            YSyncMessageType.SYNC_STEP1.name,
            self._channel.path,
        )
        await self._channel.send(sync_message)
        assert self._task_group is not None
        self._task_group.start_soon(self._send_updates)
        async for message in self._channel:
            if not message:
                # Indexing an empty frame would raise IndexError.
                self.log.debug("Ignoring empty message from endpoint: %s", self._channel.path)
                continue
            if message[0] != YMessageType.SYNC:
                # Only SYNC messages are handled here.
                continue
            sync_type = self._sync_type(message)
            if sync_type is None:
                self.log.warning(
                    "Ignoring malformed SYNC message from endpoint: %s",
                    self._channel.path,
                )
                continue
            self.log.debug(
                "Received %s message from endpoint: %s",
                sync_type.name,
                self._channel.path,
            )
            reply = handle_sync_message(message[1:], self._doc)
            if reply is not None:
                self.log.debug(
                    "Sending %s message to endpoint: %s",
                    YSyncMessageType.SYNC_STEP2.name,
                    self._channel.path,
                )
                await self._channel.send(reply)

    async def _send_updates(self):
        """Send every document event's update through the channel as an update message."""
        async with self._doc.events() as events:
            async for event in events:
                message = create_update_message(event.update)
                await self._channel.send(message)

    async def __aenter__(self) -> Provider:
        async with AsyncExitStack() as exit_stack:
            self._task_group = await exit_stack.enter_async_context(
                self._get_or_create_task_group()
            )
            await self._task_group.start(self.start)
            # Keep the task group open past this `async with`; it is closed in `__aexit__`.
            self._exit_stack = exit_stack.pop_all()
        return self

    async def __aexit__(self, exc_type, exc_value, exc_tb):
        await self.stop()
        return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

    @asynccontextmanager
    async def _get_or_create_task_group(self) -> AsyncIterator[TaskGroup]:
        """Yield the provider's task group if one is already set, otherwise a new one
        that lives for the duration of the context."""
        if self._task_group is not None:
            yield self._task_group
            return

        async with create_task_group() as tg:
            yield tg

    async def start(
        self,
        *,
        task_status: TaskStatus[None] = TASK_STATUS_IGNORED,
    ) -> None:
        """Start the provider.

        Args:
            task_status: The status to set when the task has started.
        """
        async with self._start_lock:
            async with self._get_or_create_task_group() as self._task_group:
                task_status.started()
                self.started.set()
                self._task_group.start_soon(self._run)

    async def stop(self) -> None:
        """Stop the provider by cancelling its task group."""
        assert self._task_group is not None
        self._task_group.cancel_scope.cancel()
        self._task_group = None