sockudo-python 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2465 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import base64
5
+ import json
6
+ import struct
7
+ import urllib.parse
8
+ from collections import OrderedDict
9
+ from dataclasses import dataclass, field
10
+ from enum import Enum
11
+ from typing import (
12
+ Any,
13
+ Awaitable,
14
+ Callable,
15
+ Dict,
16
+ List,
17
+ Optional,
18
+ Tuple,
19
+ Union,
20
+ )
21
+
22
+ import httpx
23
+ import msgpack
24
+ import vcdiff_decoder
25
+ from nacl.secret import SecretBox
26
+ from nacl.utils import random as nacl_random
27
+ from websockets.asyncio.client import connect as ws_connect
28
+ from websockets.exceptions import ConnectionClosed
29
+
30
+
31
class SockudoException(RuntimeError):
    """Root of the Sockudo client exception hierarchy."""


class InvalidAppKey(SockudoException):
    """Error type for an invalid application key."""


class InvalidOptions(SockudoException):
    """Error type for invalid client options."""


class UnsupportedFeature(SockudoException):
    """Error type for features that are not supported."""


class BadEventName(SockudoException):
    """Error type for unacceptable event names."""


class AuthFailure(SockudoException):
    """Authorization / authentication failure.

    Attributes:
        status_code: Status code associated with the failure (presumably the
            HTTP status of the auth request), or ``None`` when unavailable.
    """

    def __init__(self, status_code: Optional[int], message: str) -> None:
        super().__init__(message)
        self.status_code = status_code


class DeltaFailure(SockudoException):
    """Raised when a delta payload is malformed or fails verification."""
59
+
60
+
61
class ConnectionState(str, Enum):
    """Connection lifecycle states (values are the lowercase state names)."""

    INITIALIZED = "initialized"
    CONNECTING = "connecting"
    CONNECTED = "connected"
    DISCONNECTED = "disconnected"
    UNAVAILABLE = "unavailable"
    FAILED = "failed"


class SockudoTransport(str, Enum):
    """WebSocket transport schemes."""

    WS = "ws"
    WSS = "wss"


class SockudoWireFormat(str, Enum):
    """Serialization format used for protocol envelopes."""

    JSON = "json"
    MESSAGEPACK = "messagepack"
    PROTOBUF = "protobuf"

    @property
    def is_binary(self) -> bool:
        """Whether this format is binary, i.e. anything other than JSON."""
        return self is not type(self).JSON


class DeltaAlgorithm(str, Enum):
    """Delta compression algorithms that can be negotiated."""

    FOSSIL = "fossil"
    XDELTA3 = "xdelta3"
88
+
89
+
90
@dataclass
class FilterNode:
    """A node in a subscription filter tree.

    Logical nodes carry ``op`` plus child ``nodes``; leaf nodes carry
    ``key``, a comparison code ``cmp`` and either ``val`` or ``vals``.
    """

    op: Optional[str] = None
    key: Optional[str] = None
    cmp: Optional[str] = None
    val: Optional[str] = None
    vals: Optional[List[str]] = None
    nodes: Optional[List["FilterNode"]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this node recursively, skipping unset fields.

        An empty ``nodes`` list is treated as unset and omitted.
        """
        serialized: Dict[str, Any] = {}
        for attribute in ("op", "key", "cmp", "val", "vals"):
            current = getattr(self, attribute)
            if current is not None:
                serialized[attribute] = current
        if self.nodes:
            serialized["nodes"] = [child.to_dict() for child in self.nodes]
        return serialized


class Filter:
    """Convenience constructors for :class:`FilterNode` trees."""

    @staticmethod
    def _compare(key: str, cmp: str, value: str) -> FilterNode:
        # Leaf with a single comparison value.
        return FilterNode(key=key, cmp=cmp, val=value)

    @staticmethod
    def _member(key: str, cmp: str, values: List[str]) -> FilterNode:
        # Leaf with a list of candidate values.
        return FilterNode(key=key, cmp=cmp, vals=values)

    @staticmethod
    def eq(key: str, value: str) -> FilterNode:
        """Leaf using comparison code ``eq``."""
        return Filter._compare(key, "eq", value)

    @staticmethod
    def neq(key: str, value: str) -> FilterNode:
        """Leaf using comparison code ``neq``."""
        return Filter._compare(key, "neq", value)

    @staticmethod
    def inside(key: str, values: List[str]) -> FilterNode:
        """Leaf using comparison code ``in`` over ``values``."""
        return Filter._member(key, "in", values)

    @staticmethod
    def not_in(key: str, values: List[str]) -> FilterNode:
        """Leaf using comparison code ``nin`` over ``values``."""
        return Filter._member(key, "nin", values)

    @staticmethod
    def exists(key: str) -> FilterNode:
        """Leaf using comparison code ``ex`` (takes no value)."""
        return FilterNode(key=key, cmp="ex")

    @staticmethod
    def not_exists(key: str) -> FilterNode:
        """Leaf using comparison code ``nex`` (takes no value)."""
        return FilterNode(key=key, cmp="nex")

    @staticmethod
    def starts_with(key: str, value: str) -> FilterNode:
        """Leaf using comparison code ``sw``."""
        return Filter._compare(key, "sw", value)

    @staticmethod
    def ends_with(key: str, value: str) -> FilterNode:
        """Leaf using comparison code ``ew``."""
        return Filter._compare(key, "ew", value)

    @staticmethod
    def contains(key: str, value: str) -> FilterNode:
        """Leaf using comparison code ``ct``."""
        return Filter._compare(key, "ct", value)

    @staticmethod
    def gt(key: str, value: str) -> FilterNode:
        """Leaf using comparison code ``gt``."""
        return Filter._compare(key, "gt", value)

    @staticmethod
    def gte(key: str, value: str) -> FilterNode:
        """Leaf using comparison code ``gte``."""
        return Filter._compare(key, "gte", value)

    @staticmethod
    def lt(key: str, value: str) -> FilterNode:
        """Leaf using comparison code ``lt``."""
        return Filter._compare(key, "lt", value)

    @staticmethod
    def lte(key: str, value: str) -> FilterNode:
        """Leaf using comparison code ``lte``."""
        return Filter._compare(key, "lte", value)

    @staticmethod
    def and_(*nodes: FilterNode) -> FilterNode:
        """Logical ``and`` node over ``nodes``."""
        return FilterNode(op="and", nodes=[*nodes])

    @staticmethod
    def or_(*nodes: FilterNode) -> FilterNode:
        """Logical ``or`` node over ``nodes``."""
        return FilterNode(op="or", nodes=[*nodes])

    @staticmethod
    def not_(node: FilterNode) -> FilterNode:
        """Logical ``not`` node wrapping a single child."""
        return FilterNode(op="not", nodes=[node])
180
+
181
+
182
_LOGICAL_FILTER_OPS = frozenset({"and", "or", "not"})
_COMPARISON_FILTER_OPS = frozenset(
    {"eq", "neq", "in", "nin", "ex", "nex", "sw", "ew", "ct", "gt", "gte", "lt", "lte"}
)


def validate_filter(filter_node: FilterNode) -> Optional[str]:
    """Check a filter tree for structural problems.

    A node with a truthy ``op`` is validated as a logical node (recursing
    into its children); any other node is validated as a leaf.

    Returns:
        ``None`` if the tree is valid, otherwise a message describing the
        first problem found. Errors in children are prefixed with
        ``"Child node <index>: "``.
    """
    operator = filter_node.op
    if operator:
        return _validate_logical_filter(operator, filter_node.nodes)
    return _validate_leaf_filter(filter_node)


def _validate_logical_filter(
    operator: str, children: Optional[List[FilterNode]]
) -> Optional[str]:
    # Validate an and/or/not node and, recursively, its children.
    if operator not in _LOGICAL_FILTER_OPS:
        return f"Invalid logical operator: {operator}"
    if children is None:
        return f"Logical operation '{operator}' requires nodes array"
    if operator == "not":
        if len(children) != 1:
            return f"NOT operation requires exactly one child node, got {len(children)}"
    elif not children:
        return f"{operator.upper()} operation requires at least one child node"
    for position, child in enumerate(children):
        problem = validate_filter(child)
        if problem is not None:
            return f"Child node {position}: {problem}"
    return None


def _validate_leaf_filter(node: FilterNode) -> Optional[str]:
    # Validate a key/comparison leaf node.
    if not node.key:
        return "Leaf node requires a key"
    comparison = node.cmp
    if not comparison:
        return "Leaf node requires a comparison operator"
    if comparison not in _COMPARISON_FILTER_OPS:
        return f"Invalid comparison operator: {comparison}"
    if comparison in ("in", "nin"):
        return None if node.vals else f"{comparison} operation requires non-empty vals array"
    # Existence checks take no value; every other comparison needs a truthy val.
    if comparison in ("ex", "nex") or node.val:
        return None
    return f"{comparison} operation requires a val"
226
+
227
+
228
@dataclass
class ChannelDeltaSettings:
    """Per-channel delta compression preferences."""

    enabled: Optional[bool] = None
    algorithm: Optional[DeltaAlgorithm] = None

    def subscription_value(self) -> Any:
        """Return the most compact representation of these settings.

        * only ``algorithm`` set  -> the algorithm's string value
        * only ``enabled`` set to ``True``/``False`` -> that bool
        * otherwise -> a dict with whichever of ``enabled`` / ``algorithm``
          are set (possibly empty)
        """
        enabled, algorithm = self.enabled, self.algorithm
        if algorithm is None:
            if enabled is True or enabled is False:
                return enabled
        elif enabled is None:
            return algorithm.value
        payload: Dict[str, Any] = {}
        if enabled is not None:
            payload["enabled"] = enabled
        if algorithm is not None:
            payload["algorithm"] = algorithm.value
        return payload
246
+
247
+
248
@dataclass
class MessageExtras:
    """Optional extras attached to a message; every field defaults to ``None``."""

    # Header values are serialized by type (bool / number / string) by the codec.
    headers: Optional[Dict[str, Any]] = None
    ephemeral: Optional[bool] = None
    idempotency_key: Optional[str] = None
    echo: Optional[bool] = None


@dataclass
class SubscriptionOptions:
    """Optional per-subscription settings; every field defaults to ``None``."""

    filter: Optional[FilterNode] = None
    delta: Optional[ChannelDeltaSettings] = None
    events: Optional[List[str]] = None
    rewind: Optional["SubscriptionRewind"] = None
262
+
263
+
264
@dataclass
class SubscriptionRewind:
    """Rewind request expressed either as a message count or a number of seconds."""

    count: Optional[int] = None
    seconds: Optional[int] = None

    @classmethod
    def count_messages(cls, count: int) -> "SubscriptionRewind":
        """Build a count-based rewind."""
        return cls(count=count)

    @classmethod
    def seconds_back(cls, seconds: int) -> "SubscriptionRewind":
        """Build a time-based rewind."""
        return cls(seconds=seconds)

    def subscription_value(self) -> Any:
        """Return the bare count, or ``{"seconds": n}``; ``count`` wins if both are set.

        Raises:
            SockudoException: if neither field is set.
        """
        if self.count is None:
            if self.seconds is None:
                raise SockudoException("SubscriptionRewind requires count or seconds")
            return {"seconds": self.seconds}
        return self.count
283
+
284
+
285
@dataclass
class DeltaStats:
    """Counters describing delta compression effectiveness."""

    total_messages: int = 0
    delta_messages: int = 0
    full_messages: int = 0
    total_bytes_without_compression: int = 0
    total_bytes_with_compression: int = 0
    errors: int = 0

    @property
    def bandwidth_saved(self) -> int:
        """Uncompressed byte total minus compressed byte total."""
        uncompressed = self.total_bytes_without_compression
        compressed = self.total_bytes_with_compression
        return uncompressed - compressed

    @property
    def bandwidth_saved_percent(self) -> float:
        """``bandwidth_saved`` as a percentage of the uncompressed total (0.0 if none)."""
        baseline = self.total_bytes_without_compression
        if baseline != 0:
            return self.bandwidth_saved / baseline * 100.0
        return 0.0
303
+
304
+
305
@dataclass
class DeltaOptions:
    """Client-wide delta compression configuration."""

    enabled: Optional[bool] = None
    # Fresh list per instance; both algorithms are offered by default.
    algorithms: List[DeltaAlgorithm] = field(
        default_factory=lambda: [DeltaAlgorithm.FOSSIL, DeltaAlgorithm.XDELTA3]
    )
    debug: bool = False
    on_stats: Optional[Callable[[DeltaStats], None]] = None
    on_error: Optional[Callable[[BaseException], None]] = None


@dataclass
class PresenceMember:
    """A presence channel member: its id plus arbitrary info."""

    id: str
    info: Any
320
+
321
+
322
# Scalar types accepted as auth request parameter values.
AuthValue = Union[str, int, float, bool]

# Zero-argument providers of extra request headers / parameters.
HeadersProvider = Callable[[], Dict[str, str]]
ParamsProvider = Callable[[], Dict[str, AuthValue]]
PresenceHistoryHeadersProvider = Callable[[], Dict[str, str]]

# Async handlers that replace the default HTTP auth round-trip.
ChannelAuthHandler = Callable[
    ["ChannelAuthorizationRequest"], Awaitable["ChannelAuthorizationData"]
]
UserAuthHandler = Callable[
    ["UserAuthenticationRequest"], Awaitable["UserAuthenticationData"]
]
332
+
333
+
334
@dataclass
class ChannelAuthorizationData:
    """Result of a channel authorization."""

    auth: str
    channel_data: Optional[str] = None
    shared_secret: Optional[str] = None


@dataclass
class UserAuthenticationData:
    """Result of a user authentication."""

    auth: str
    user_data: str


@dataclass
class ChannelAuthorizationRequest:
    """Input handed to a channel authorization handler."""

    socket_id: str
    channel_name: str


@dataclass
class UserAuthenticationRequest:
    """Input handed to a user authentication handler."""

    socket_id: str
356
+
357
+
358
@dataclass
class ChannelAuthorizationOptions:
    """How private/presence channel authorization requests are made."""

    endpoint: str = "/sockudo/auth"
    headers: Dict[str, str] = field(default_factory=dict)
    params: Dict[str, AuthValue] = field(default_factory=dict)
    headers_provider: Optional[HeadersProvider] = None
    params_provider: Optional[ParamsProvider] = None
    # When set, presumably used instead of calling ``endpoint``.
    custom_handler: Optional[ChannelAuthHandler] = None


@dataclass
class UserAuthenticationOptions:
    """How user authentication requests are made."""

    endpoint: str = "/sockudo/user-auth"
    headers: Dict[str, str] = field(default_factory=dict)
    params: Dict[str, AuthValue] = field(default_factory=dict)
    headers_provider: Optional[HeadersProvider] = None
    params_provider: Optional[ParamsProvider] = None
    # When set, presumably used instead of calling ``endpoint``.
    custom_handler: Optional[UserAuthHandler] = None


@dataclass
class PresenceHistoryOptions:
    """Endpoint configuration for presence history queries."""

    endpoint: str
    headers: Dict[str, str] = field(default_factory=dict)
    headers_provider: Optional[PresenceHistoryHeadersProvider] = None


@dataclass
class SockudoOptions:
    """Client configuration. Only ``cluster`` is required.

    NOTE(review): timeout fields are floats; their unit is not visible here
    (presumably seconds) — confirm against the connection code.
    """

    cluster: str
    protocol_version: int = 2
    activity_timeout: float = 120.0
    force_tls: Optional[bool] = None
    enabled_transports: Optional[List[SockudoTransport]] = None
    disabled_transports: Optional[List[SockudoTransport]] = None
    # WebSocket endpoint.
    ws_host: Optional[str] = None
    ws_port: int = 80
    wss_port: int = 443
    ws_path: str = ""
    # HTTP endpoint.
    http_host: Optional[str] = None
    http_port: int = 80
    https_port: int = 443
    http_path: str = "/sockudo"
    pong_timeout: float = 30.0
    unavailable_timeout: float = 10.0
    # Stats reporting.
    enable_stats: bool = False
    stats_host: str = "stats.sockudo.io"
    timeline_params: Dict[str, AuthValue] = field(default_factory=dict)
    # Auth, presence history and delta compression.
    channel_authorization: ChannelAuthorizationOptions = field(
        default_factory=ChannelAuthorizationOptions
    )
    user_authentication: UserAuthenticationOptions = field(
        default_factory=UserAuthenticationOptions
    )
    presence_history: Optional[PresenceHistoryOptions] = None
    delta_compression: Optional[DeltaOptions] = None
    # Message handling.
    message_deduplication: bool = True
    message_deduplication_capacity: int = 1000
    connection_recovery: bool = False
    echo_messages: bool = True
    wire_format: SockudoWireFormat = SockudoWireFormat.JSON
419
+
420
+
421
@dataclass
class EventMetadata:
    """Extra information passed to event callbacks alongside the event data."""

    user_id: Optional[str] = None


@dataclass
class PresenceHistoryParams:
    """Query parameters for a presence history request.

    ``start`` / ``end`` are fallbacks for ``start_time_ms`` / ``end_time_ms``
    and are only used when the ``*_time_ms`` field is unset.
    """

    direction: Optional[str] = None
    limit: Optional[int] = None
    cursor: Optional[str] = None
    start_serial: Optional[int] = None
    end_serial: Optional[int] = None
    start_time_ms: Optional[int] = None
    end_time_ms: Optional[int] = None
    start: Optional[int] = None
    end: Optional[int] = None

    def to_payload(self) -> Dict[str, Any]:
        """Return a dict of the set parameters (``None`` values are omitted)."""
        start_time = self.start if self.start_time_ms is None else self.start_time_ms
        end_time = self.end if self.end_time_ms is None else self.end_time_ms
        candidates = (
            ("direction", self.direction),
            ("limit", self.limit),
            ("cursor", self.cursor),
            ("start_serial", self.start_serial),
            ("end_serial", self.end_serial),
            ("start_time_ms", start_time),
            ("end_time_ms", end_time),
        )
        return {name: value for name, value in candidates if value is not None}
459
+
460
+
461
@dataclass
class PresenceHistoryBounds:
    """Serial and time bounds of a history page."""

    start_serial: Optional[int]
    end_serial: Optional[int]
    start_time_ms: Optional[int]
    end_time_ms: Optional[int]


@dataclass
class PresenceHistoryContinuity:
    """Retention / completeness information for a history stream."""

    stream_id: Optional[str]
    oldest_available_serial: Optional[int]
    newest_available_serial: Optional[int]
    oldest_available_published_at_ms: Optional[int]
    newest_available_published_at_ms: Optional[int]
    retained_events: int
    retained_bytes: int
    degraded: bool
    complete: bool
    truncated_by_retention: bool


@dataclass
class PresenceHistoryItem:
    """One presence history record."""

    stream_id: str
    serial: int
    published_at_ms: int
    event: str
    cause: str
    user_id: str
    connection_id: Optional[str]
    dead_node_id: Optional[str]
    payload_size_bytes: int
    presence_event: Dict[str, Any]


@dataclass
class PresenceSnapshotParams:
    """Parameters for a point-in-time presence snapshot.

    ``at`` is a fallback for ``at_time_ms`` used only when ``at_time_ms`` is unset.
    """

    at_time_ms: Optional[int] = None
    at: Optional[int] = None
    at_serial: Optional[int] = None

    def to_payload(self) -> Dict[str, Any]:
        """Return a dict of the set parameters (``None`` values are omitted)."""
        payload: Dict[str, Any] = {}
        moment = self.at if self.at_time_ms is None else self.at_time_ms
        if moment is not None:
            payload["at_time_ms"] = moment
        if self.at_serial is not None:
            payload["at_serial"] = self.at_serial
        return payload


@dataclass
class PresenceSnapshotMember:
    """A member present in a snapshot, with its last recorded event."""

    user_id: str
    last_event: str
    last_event_serial: int
    last_event_at_ms: int


@dataclass
class PresenceSnapshot:
    """Reconstructed membership of a channel at a point in time."""

    channel: str
    members: List[PresenceSnapshotMember]
    member_count: int
    events_replayed: int
    snapshot_serial: Optional[int]
    snapshot_time_ms: Optional[int]
    continuity: PresenceHistoryContinuity
531
+
532
+
533
@dataclass
class PresenceHistoryPage:
    """One page of presence history results.

    ``_fetch_next`` is an optional coroutine factory that, given a cursor,
    loads the following page.
    """

    items: List[PresenceHistoryItem]
    direction: str
    limit: int
    has_more: bool
    next_cursor: Optional[str]
    bounds: PresenceHistoryBounds
    continuity: PresenceHistoryContinuity
    _fetch_next: Optional[Callable[[str], Awaitable["PresenceHistoryPage"]]] = None

    def has_next(self) -> bool:
        """Whether the server reported more data and supplied a cursor for it."""
        cursor_known = self.next_cursor is not None
        return self.has_more and cursor_known

    async def next(self) -> "PresenceHistoryPage":
        """Fetch the following page.

        Raises:
            SockudoException: if there is no next page or no fetcher.
        """
        fetch = self._fetch_next
        if not self.has_next() or fetch is None:
            raise SockudoException("No more pages available")
        return await fetch(self.next_cursor)
551
+
552
+
553
@dataclass
class SockudoEvent:
    """A decoded protocol event.

    ``raw_message`` holds the envelope as JSON text (for binary wire formats
    the codec re-serializes the decoded envelope to JSON).
    """

    event: str
    channel: Optional[str]
    data: Any
    user_id: Optional[str]
    message_id: Optional[str]
    stream_id: Optional[str]
    raw_message: str
    sequence: Optional[int] = None
    conflation_key: Optional[str] = None
    serial: Optional[int] = None
    extras: Optional[MessageExtras] = None


@dataclass
class RecoveryPosition:
    """Stream position (presumably used for connection recovery)."""

    serial: int
    stream_id: Optional[str] = None
    last_message_id: Optional[str] = None
573
+
574
+
575
class ProtocolPrefix:
    """Event-name prefixes for a given protocol version.

    Versions >= 2 use the ``sockudo`` prefixes and report version ``"2"``;
    older versions use the ``pusher`` prefixes and report version ``"7"``.
    """

    def __init__(self, version: int) -> None:
        modern = version >= 2
        if modern:
            self.version = "2"
            self.event_prefix = "sockudo:"
            self.internal_prefix = "sockudo_internal:"
        else:
            self.version = "7"
            self.event_prefix = "pusher:"
            self.internal_prefix = "pusher_internal:"

    def event(self, name: str) -> str:
        """Return ``name`` with the platform event prefix."""
        prefix = self.event_prefix
        return f"{prefix}{name}"

    def internal(self, name: str) -> str:
        """Return ``name`` with the internal event prefix."""
        prefix = self.internal_prefix
        return f"{prefix}{name}"

    def is_internal_event(self, name: str) -> bool:
        """Whether ``name`` carries the internal prefix."""
        return name.startswith(self.internal_prefix)

    def is_platform_event(self, name: str) -> bool:
        """Whether ``name`` carries the platform event prefix."""
        return name.startswith(self.event_prefix)
588
+
589
+ def is_internal_event(self, name: str) -> bool:
590
+ return name.startswith(self.internal_prefix)
591
+
592
+ def is_platform_event(self, name: str) -> bool:
593
+ return name.startswith(self.event_prefix)
594
+
595
+
596
class EventDispatcher:
    """Registry of per-event and global callbacks.

    Every registration returns a random URL-safe token that can later be
    passed to :meth:`unbind` / :meth:`unbind_global` to remove it.
    """

    def __init__(
        self, failthrough: Optional[Callable[[str, Any], None]] = None
    ) -> None:
        """Create an empty dispatcher.

        Args:
            failthrough: Called with ``(event_name, data)`` when an event is
                emitted that has no per-event callbacks.
        """
        self._callbacks: Dict[
            str, "OrderedDict[str, Callable[[Any, Optional[EventMetadata]], None]]"
        ] = {}
        self._global_callbacks: "OrderedDict[str, Callable[[str, Any], None]]" = (
            OrderedDict()
        )
        self._failthrough = failthrough

    def bind(
        self, event_name: str, callback: Callable[[Any, Optional[EventMetadata]], None]
    ) -> str:
        """Register ``callback(data, metadata)`` for ``event_name``; return its token."""
        token = base64.urlsafe_b64encode(nacl_random(9)).decode("ascii")
        self._callbacks.setdefault(event_name, OrderedDict())[token] = callback
        return token

    def bind_global(self, callback: Callable[[str, Any], None]) -> str:
        """Register ``callback(event_name, data)`` for every event; return its token."""
        token = base64.urlsafe_b64encode(nacl_random(9)).decode("ascii")
        self._global_callbacks[token] = callback
        return token

    def unbind_global(self, token: Optional[str] = None) -> None:
        """Remove one global callback by token, or all of them when ``token`` is None."""
        if token is None:
            self._global_callbacks.clear()
            return
        self._global_callbacks.pop(token, None)

    def unbind(
        self, event_name: Optional[str] = None, token: Optional[str] = None
    ) -> None:
        """Remove callbacks.

        * ``event_name`` only: drop every callback for that event.
        * ``event_name`` and ``token``: drop that one callback.
        * ``token`` only: drop that token from every event and from the
          global callbacks.
        * neither: drop every per-event and global callback.
        """
        if event_name is not None and token is None:
            self._callbacks.pop(event_name, None)
            return
        if event_name is not None and token is not None:
            callbacks = self._callbacks.get(event_name)
            if callbacks is None:
                return
            callbacks.pop(token, None)
            if not callbacks:
                self._callbacks.pop(event_name, None)
            return
        if token is not None:
            for name in list(self._callbacks):
                self._callbacks[name].pop(token, None)
                if not self._callbacks[name]:
                    self._callbacks.pop(name, None)
            self._global_callbacks.pop(token, None)
            return
        self._callbacks.clear()
        self._global_callbacks.clear()

    def emit(
        self, event_name: str, data: Any, metadata: Optional[EventMetadata] = None
    ) -> None:
        """Dispatch an event to global callbacks, then to its own callbacks.

        If the event has no per-event callbacks, ``failthrough`` (when set) is
        called instead.
        """
        # Iterate over snapshots: a callback may bind/unbind while we are
        # dispatching (e.g. a one-shot handler removing itself). Mutating an
        # OrderedDict during iteration raises RuntimeError. Callbacks removed
        # mid-dispatch still run for this emit.
        for callback in list(self._global_callbacks.values()):
            callback(event_name, data)
        callbacks = self._callbacks.get(event_name)
        if not callbacks:
            if self._failthrough is not None:
                self._failthrough(event_name, data)
            return
        for callback in list(callbacks.values()):
            callback(data, metadata)
662
+
663
+
664
class MessageDeduplicator:
    """Bounded record of recently seen message ids.

    When more than ``capacity`` ids are tracked, the least recently
    *tracked* id is evicted (``is_duplicate`` does not refresh an id).
    """

    def __init__(self, capacity: int = 1000) -> None:
        # Clamp negative capacities to 0. With a negative bound the eviction
        # loop in ``track`` would drain the dict and then raise KeyError from
        # ``popitem`` on an empty OrderedDict.
        self._capacity = max(0, capacity)
        self._seen: "OrderedDict[str, bool]" = OrderedDict()

    def is_duplicate(self, message_id: str) -> bool:
        """Whether ``message_id`` is currently remembered."""
        return message_id in self._seen

    def track(self, message_id: str) -> None:
        """Remember ``message_id`` as most recent, evicting the oldest ids beyond capacity."""
        self._seen[message_id] = True
        # Re-tracking an existing id moves it to the most-recent end.
        self._seen.move_to_end(message_id)
        while len(self._seen) > self._capacity:
            self._seen.popitem(last=False)
677
+
678
+
679
class FossilDelta:
    """Decoder for Fossil-format deltas.

    Integers in a delta use the base-64 alphabet in ``_digits``. A delta
    starts with the expected output size followed by a newline, then a
    sequence of copy (``<n>@<offset>,``) and insert (``<n>:<bytes>``)
    commands, and ends with a ``<checksum>;`` command.
    """

    _digits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz~"
    # Maps each alphabet byte to its digit value.
    _values = {ord(ch): index for index, ch in enumerate(_digits)}

    class _Reader:
        """Forward-only cursor over the raw delta bytes."""

        def __init__(self, data: bytes) -> None:
            self.data = data
            self.position = 0

        @property
        def has_bytes(self) -> bool:
            return self.position < len(self.data)

        def byte(self) -> int:
            if self.position >= len(self.data):
                raise DeltaFailure("out of bounds")
            value = self.data[self.position]
            self.position += 1
            return value

        def character(self) -> str:
            return chr(self.byte())

        def integer(self) -> int:
            """Read a base-64 integer, stopping before the first non-digit byte."""
            value = 0
            while self.has_bytes:
                raw = self.byte()
                mapped = FossilDelta._values.get(raw, -1)
                if mapped < 0:
                    # Un-read the terminator so the caller can inspect it.
                    self.position -= 1
                    break
                value = (value << 6) + mapped
            return value

    @staticmethod
    def apply(base: bytes, delta: bytes) -> bytes:
        """Reconstruct the target bytes from ``base`` and a Fossil ``delta``.

        Raises:
            DeltaFailure: if the delta is malformed, references data outside
                ``base``/``delta``, produces the wrong size, or fails the
                checksum.
        """
        reader = FossilDelta._Reader(delta)
        output_size = reader.integer()
        if reader.character() != "\n":
            raise DeltaFailure("size integer not terminated by newline")
        output = bytearray()
        total = 0
        while reader.has_bytes:
            count = reader.integer()
            op = reader.character()
            if op == "@":
                # Copy ``count`` bytes from ``base`` starting at ``offset``.
                offset = reader.integer()
                if reader.has_bytes and reader.character() != ",":
                    raise DeltaFailure("copy command not terminated by comma")
                total += count
                if total > output_size:
                    raise DeltaFailure("copy exceeds output file size")
                if offset + count > len(base):
                    raise DeltaFailure("copy extends past end of input")
                output.extend(base[offset : offset + count])
            elif op == ":":
                # Insert the next ``count`` literal bytes from the delta itself.
                total += count
                if total > output_size:
                    raise DeltaFailure(
                        "insert command gives an output larger than predicted"
                    )
                if reader.position + count > len(delta):
                    raise DeltaFailure("insert count exceeds size of delta")
                output.extend(delta[reader.position : reader.position + count])
                reader.position += count
            elif op == ";":
                # Terminator: ``count`` is the checksum of the whole output.
                payload = bytes(output)
                if count != FossilDelta._checksum(payload):
                    raise DeltaFailure("bad checksum")
                if total != output_size:
                    raise DeltaFailure("generated size does not match predicted size")
                return payload
            else:
                raise DeltaFailure("unknown delta operator")
        raise DeltaFailure("unterminated delta")

    @staticmethod
    def _checksum(data: bytes) -> int:
        """Return Fossil's 32-bit checksum of ``data``.

        The reference implementation accumulates in unsigned 32-bit
        arithmetic. Only additions and left shifts are involved, so reducing
        the final value modulo 2**32 gives the same result. Without that
        reduction, Python's unbounded integers yield values that never match
        the checksum encoded in the delta.
        """
        n_hash = 16
        sum0 = sum1 = sum2 = sum3 = 0
        index = 0
        remaining = len(data)
        while remaining >= n_hash:
            sum0 += (
                data[index + 0] + data[index + 4] + data[index + 8] + data[index + 12]
            )
            sum1 += (
                data[index + 1] + data[index + 5] + data[index + 9] + data[index + 13]
            )
            sum2 += (
                data[index + 2] + data[index + 6] + data[index + 10] + data[index + 14]
            )
            sum3 += (
                data[index + 3] + data[index + 7] + data[index + 11] + data[index + 15]
            )
            index += n_hash
            remaining -= n_hash
        while remaining >= 4:
            sum0 += data[index + 0]
            sum1 += data[index + 1]
            sum2 += data[index + 2]
            sum3 += data[index + 3]
            index += 4
            remaining -= 4
        sum3 += (sum2 << 8) + (sum1 << 16) + (sum0 << 24)
        # Fold in the 1-3 trailing bytes, most significant first.
        if remaining == 3:
            sum3 += data[index + 2] << 8
            sum3 += data[index + 1] << 16
            sum3 += data[index + 0] << 24
        elif remaining == 2:
            sum3 += data[index + 1] << 16
            sum3 += data[index + 0] << 24
        elif remaining == 1:
            sum3 += data[index + 0] << 24
        return sum3 & 0xFFFFFFFF
794
+
795
+
796
+ class ProtocolCodec:
797
+ _messagepack_fields = [
798
+ "event",
799
+ "channel",
800
+ "data",
801
+ "name",
802
+ "user_id",
803
+ "tags",
804
+ "sequence",
805
+ "conflation_key",
806
+ "message_id",
807
+ "serial",
808
+ "idempotency_key",
809
+ "extras",
810
+ "__delta_seq",
811
+ "__conflation_key",
812
+ "stream_id",
813
+ ]
814
+
815
+ @staticmethod
816
+ def encode_envelope(
817
+ envelope: Dict[str, Any], wire_format: SockudoWireFormat
818
+ ) -> Union[str, bytes]:
819
+ if wire_format is SockudoWireFormat.JSON:
820
+ return json.dumps(envelope, separators=(",", ":"))
821
+ if wire_format is SockudoWireFormat.MESSAGEPACK:
822
+ payload = [
823
+ envelope.get("event"),
824
+ envelope.get("channel"),
825
+ ProtocolCodec._encode_messagepack_data(envelope.get("data")),
826
+ envelope.get("name"),
827
+ envelope.get("user_id"),
828
+ envelope.get("tags"),
829
+ envelope.get("sequence"),
830
+ envelope.get("conflation_key"),
831
+ envelope.get("message_id"),
832
+ envelope.get("serial"),
833
+ envelope.get("idempotency_key"),
834
+ ProtocolCodec._encode_messagepack_extras(envelope.get("extras")),
835
+ envelope.get("__delta_seq"),
836
+ envelope.get("__conflation_key"),
837
+ envelope.get("stream_id"),
838
+ ]
839
+ return msgpack.packb(payload, use_bin_type=True)
840
+ return ProtocolCodec._encode_protobuf(envelope)
841
+
842
+ @staticmethod
843
+ def decode_event(
844
+ raw_message: Union[str, bytes], wire_format: SockudoWireFormat
845
+ ) -> SockudoEvent:
846
+ envelope, raw_text = ProtocolCodec.decode_envelope(raw_message, wire_format)
847
+ raw_data = envelope.get("data")
848
+ data = raw_data
849
+ if isinstance(raw_data, str):
850
+ try:
851
+ data = json.loads(raw_data)
852
+ except json.JSONDecodeError:
853
+ data = raw_data
854
+ return SockudoEvent(
855
+ event=envelope["event"],
856
+ channel=envelope.get("channel"),
857
+ data=data,
858
+ user_id=envelope.get("user_id"),
859
+ message_id=envelope.get("message_id"),
860
+ stream_id=envelope.get("stream_id"),
861
+ raw_message=raw_text,
862
+ sequence=_coerce_int(envelope.get("__delta_seq", envelope.get("sequence"))),
863
+ conflation_key=envelope.get(
864
+ "__conflation_key", envelope.get("conflation_key")
865
+ ),
866
+ serial=_coerce_int(envelope.get("serial")),
867
+ extras=ProtocolCodec._decode_extras(envelope.get("extras")),
868
+ )
869
+
870
+ @staticmethod
871
+ def decode_envelope(
872
+ raw_message: Union[str, bytes], wire_format: SockudoWireFormat
873
+ ) -> Tuple[Dict[str, Any], str]:
874
+ if wire_format is SockudoWireFormat.JSON:
875
+ text = (
876
+ raw_message.decode("utf-8")
877
+ if isinstance(raw_message, (bytes, bytearray))
878
+ else raw_message
879
+ )
880
+ decoded = json.loads(text)
881
+ if not isinstance(decoded, dict):
882
+ raise SockudoException("Unable to decode event envelope")
883
+ return decoded, text
884
+ if wire_format is SockudoWireFormat.MESSAGEPACK:
885
+ unpacked = msgpack.unpackb(
886
+ raw_message
887
+ if isinstance(raw_message, bytes)
888
+ else raw_message.encode("utf-8"),
889
+ raw=False,
890
+ )
891
+ if isinstance(unpacked, list):
892
+ envelope = {}
893
+ for index, field in enumerate(ProtocolCodec._messagepack_fields):
894
+ if index < len(unpacked):
895
+ value = ProtocolCodec._decode_messagepack_value(unpacked[index])
896
+ if value is not None:
897
+ envelope[field] = value
898
+ elif isinstance(unpacked, dict):
899
+ envelope = {
900
+ str(key): ProtocolCodec._decode_messagepack_value(value)
901
+ for key, value in unpacked.items()
902
+ }
903
+ else:
904
+ raise SockudoException("Unable to decode event envelope")
905
+ return envelope, json.dumps(envelope, separators=(",", ":"))
906
+ envelope = ProtocolCodec._decode_protobuf(
907
+ raw_message
908
+ if isinstance(raw_message, bytes)
909
+ else raw_message.encode("utf-8")
910
+ )
911
+ return envelope, json.dumps(envelope, separators=(",", ":"))
912
+
913
+ @staticmethod
914
+ def _encode_messagepack_data(value: Any) -> Any:
915
+ if value is None:
916
+ return None
917
+ if isinstance(value, str):
918
+ return ["string", value]
919
+ return ["json", json.dumps(value, separators=(",", ":"))]
920
+
921
+ @staticmethod
922
+ def _encode_messagepack_extras(raw_extras: Any) -> Any:
923
+ extras = ProtocolCodec._decode_extras(raw_extras)
924
+ if extras is None:
925
+ return None
926
+ encoded: Dict[str, Any] = {}
927
+ if extras.headers is not None:
928
+ encoded_headers = {}
929
+ for key, value in extras.headers.items():
930
+ if isinstance(value, bool):
931
+ encoded_headers[key] = ["bool", value]
932
+ elif isinstance(value, (int, float)):
933
+ encoded_headers[key] = ["number", float(value)]
934
+ else:
935
+ encoded_headers[key] = ["string", str(value)]
936
+ encoded["headers"] = encoded_headers
937
+ if extras.ephemeral is not None:
938
+ encoded["ephemeral"] = extras.ephemeral
939
+ if extras.idempotency_key is not None:
940
+ encoded["idempotency_key"] = extras.idempotency_key
941
+ if extras.echo is not None:
942
+ encoded["echo"] = extras.echo
943
+ return encoded
944
+
945
+ @staticmethod
946
+ def _decode_messagepack_value(value: Any) -> Any:
947
+ if isinstance(value, list):
948
+ if len(value) == 2 and isinstance(value[0], str):
949
+ tag = value[0]
950
+ if tag in {"string", "json", "number", "bool"}:
951
+ return value[1]
952
+ return [ProtocolCodec._decode_messagepack_value(item) for item in value]
953
+ if isinstance(value, dict):
954
+ return {
955
+ str(key): ProtocolCodec._decode_messagepack_value(item)
956
+ for key, item in value.items()
957
+ }
958
+ return value
959
+
960
+ @staticmethod
961
+ def _decode_extras(raw_extras: Any) -> Optional[MessageExtras]:
962
+ if raw_extras is None:
963
+ return None
964
+ if isinstance(raw_extras, MessageExtras):
965
+ return raw_extras
966
+ if not isinstance(raw_extras, dict):
967
+ return None
968
+ headers = raw_extras.get("headers")
969
+ if isinstance(headers, dict):
970
+ decoded_headers = {}
971
+ for key, value in headers.items():
972
+ if isinstance(value, list) and len(value) == 2:
973
+ decoded_headers[key] = value[1]
974
+ else:
975
+ decoded_headers[key] = value
976
+ headers = decoded_headers
977
+ return MessageExtras(
978
+ headers=headers if isinstance(headers, dict) else None,
979
+ ephemeral=raw_extras.get("ephemeral"),
980
+ idempotency_key=raw_extras.get("idempotency_key"),
981
+ echo=raw_extras.get("echo"),
982
+ )
983
+
984
@staticmethod
def _encode_protobuf(envelope: Dict[str, Any]) -> bytes:
    """Serialise an envelope dict into protobuf wire bytes.

    Field numbers, written in ascending order: 1 event, 2 channel,
    3 data (nested message), 5 user_id, 7 sequence, 8 conflation_key,
    9 message_id, 10 serial, 12 extras (nested message), 13 __delta_seq,
    14 __conflation_key, 15 stream_id. Each value is handed to the matching
    ``_write_*`` helper (presumably those skip ``None`` — verify).
    """
    buffer = bytearray()
    get = envelope.get

    _write_string_field(buffer, 1, get("event"))
    _write_string_field(buffer, 2, get("channel"))

    payload = get("data")
    if payload is not None:
        # Nested data message: plain strings go to field 1, anything else
        # is serialised as compact JSON into field 3.
        inner = bytearray()
        if isinstance(payload, str):
            _write_string_field(inner, 1, payload)
        else:
            compact_json = json.dumps(payload, separators=(",", ":"))
            _write_string_field(inner, 3, compact_json)
        _write_bytes_field(buffer, 3, bytes(inner))

    _write_string_field(buffer, 5, get("user_id"))
    _write_uint_field(buffer, 7, get("sequence"))
    _write_string_field(buffer, 8, get("conflation_key"))
    _write_string_field(buffer, 9, get("message_id"))
    _write_uint_field(buffer, 10, get("serial"))

    encoded_extras = ProtocolCodec._encode_protobuf_extras(get("extras"))
    if encoded_extras is not None:
        _write_bytes_field(buffer, 12, encoded_extras)

    _write_uint_field(buffer, 13, get("__delta_seq"))
    _write_string_field(buffer, 14, get("__conflation_key"))
    _write_string_field(buffer, 15, get("stream_id"))
    return bytes(buffer)
1009
+
1010
@staticmethod
def _encode_protobuf_extras(raw_extras: Any) -> Optional[bytes]:
    """Encode message extras as a nested protobuf message.

    Returns ``None`` when ``_decode_extras`` produces nothing for
    *raw_extras*; otherwise the encoded bytes (possibly empty). Layout:
    field 1 = repeated header entry {1: key, 2: typed value}, where the
    typed value holds a string (1), a double (2) or a bool (3);
    field 2 = ephemeral, field 3 = idempotency_key, field 4 = echo.
    """
    extras = ProtocolCodec._decode_extras(raw_extras)
    if extras is None:
        return None

    encoded = bytearray()
    for header_name, header_value in (extras.headers or {}).items():
        typed_value = bytearray()
        # bool is tested first because bool is a subclass of int.
        if isinstance(header_value, bool):
            _write_bool_field(typed_value, 3, header_value)
        elif isinstance(header_value, (int, float)):
            _write_double_field(typed_value, 2, float(header_value))
        else:
            _write_string_field(typed_value, 1, str(header_value))

        header_entry = bytearray()
        _write_string_field(header_entry, 1, header_name)
        _write_bytes_field(header_entry, 2, bytes(typed_value))
        _write_bytes_field(encoded, 1, bytes(header_entry))

    _write_optional_bool_field(encoded, 2, extras.ephemeral)
    _write_string_field(encoded, 3, extras.idempotency_key)
    _write_optional_bool_field(encoded, 4, extras.echo)
    return bytes(encoded)
1033
+
1034
@staticmethod
def _decode_protobuf(payload: bytes) -> Dict[str, Any]:
    """Parse protobuf-encoded envelope bytes into a plain dict.

    Mirrors ``_encode_protobuf``: string fields 1/2/5/8/9/14/15, varint
    fields 7/10/13, nested data (3) and nested extras (12). A known field
    number arriving with an unexpected wire type is skipped instead of
    being misread, matching the wire-type checks done by the nested
    decoders; unknown fields are skipped too.
    """
    # Built once per call rather than once per decoded field.
    string_fields = {
        1: "event",
        2: "channel",
        5: "user_id",
        8: "conflation_key",
        9: "message_id",
        14: "__conflation_key",
        15: "stream_id",
    }
    varint_fields = {7: "sequence", 10: "serial", 13: "__delta_seq"}

    envelope: Dict[str, Any] = {}
    index = 0
    while index < len(payload):
        tag, index = _read_varint(payload, index)
        # "field_number" rather than "field" to avoid shadowing dataclasses.field.
        field_number = tag >> 3
        wire_type = tag & 0x7
        if wire_type == 2 and field_number in string_fields:
            raw, index = _read_length_delimited(payload, index)
            envelope[string_fields[field_number]] = raw.decode("utf-8")
        elif wire_type == 0 and field_number in varint_fields:
            value, index = _read_varint(payload, index)
            envelope[varint_fields[field_number]] = value
        elif wire_type == 2 and field_number == 3:
            raw, index = _read_length_delimited(payload, index)
            envelope["data"] = ProtocolCodec._decode_proto_data(raw)
        elif wire_type == 2 and field_number == 12:
            raw, index = _read_length_delimited(payload, index)
            envelope["extras"] = ProtocolCodec._decode_proto_extras(raw)
        else:
            index = _skip_unknown(payload, index, wire_type)
    return envelope
1069
+
1070
@staticmethod
def _decode_proto_data(payload: bytes) -> Any:
    """Extract the textual payload of a nested data message.

    Field 1 (plain string) wins over field 3 (JSON text); both are returned
    as decoded ``str`` without further parsing. Returns ``None`` when
    neither is present. A repeated field keeps its last occurrence.
    """
    plain_text = None
    json_text = None
    cursor = 0
    while cursor < len(payload):
        tag, cursor = _read_varint(payload, cursor)
        number, wire = tag >> 3, tag & 0x7
        if wire != 2 or number not in (1, 3):
            cursor = _skip_unknown(payload, cursor, wire)
            continue
        chunk, cursor = _read_length_delimited(payload, cursor)
        text = chunk.decode("utf-8")
        if number == 1:
            plain_text = text
        else:
            json_text = text
    return plain_text if plain_text is not None else json_text
1088
+
1089
@staticmethod
def _decode_proto_extras(payload: bytes) -> Dict[str, Any]:
    """Decode a nested extras message into a dict.

    Recognised fields: 1 header entry (repeated), 2 ephemeral (varint ->
    bool), 3 idempotency_key (string), 4 echo (varint -> bool). Keys appear
    only for fields that were present; ``headers`` is added last and only
    if at least one entry carried a key. Other fields are skipped.
    """
    decoded: Dict[str, Any] = {}
    header_map: Dict[str, Any] = {}
    pos = 0
    end = len(payload)
    while pos < end:
        tag, pos = _read_varint(payload, pos)
        number = tag >> 3
        wire = tag & 0x7
        if (number, wire) == (1, 2):
            entry_bytes, pos = _read_length_delimited(payload, pos)
            name, value = ProtocolCodec._decode_proto_header_entry(entry_bytes)
            if name is not None:
                header_map[name] = value
        elif (number, wire) == (3, 2):
            raw, pos = _read_length_delimited(payload, pos)
            decoded["idempotency_key"] = raw.decode("utf-8")
        elif wire == 0 and number in (2, 4):
            flag, pos = _read_varint(payload, pos)
            decoded["ephemeral" if number == 2 else "echo"] = bool(flag)
        else:
            pos = _skip_unknown(payload, pos, wire)
    if header_map:
        decoded["headers"] = header_map
    return decoded
1117
+
1118
@staticmethod
def _decode_proto_header_entry(payload: bytes) -> Tuple[Optional[str], Any]:
    """Decode one header-entry message into a ``(key, value)`` pair.

    Field 1 is the UTF-8 key; field 2 is a nested typed value decoded by
    ``_decode_proto_extra_value``. Absent fields come back as ``None``; a
    repeated field keeps its last occurrence.
    """
    header_key = None
    header_value = None
    offset = 0
    while offset < len(payload):
        tag, offset = _read_varint(payload, offset)
        number, wire = divmod(tag, 8)
        if wire != 2 or number not in (1, 2):
            offset = _skip_unknown(payload, offset, wire)
            continue
        raw, offset = _read_length_delimited(payload, offset)
        if number == 1:
            header_key = raw.decode("utf-8")
        else:
            header_value = ProtocolCodec._decode_proto_extra_value(raw)
    return header_key, header_value
1136
+
1137
@staticmethod
def _decode_proto_extra_value(payload: bytes) -> Any:
    """Decode a nested typed header value.

    Returns the first recognised field: 1 = UTF-8 string, 2 = double
    (little-endian fixed64), 3 = bool (varint). Unrecognised fields are
    skipped; ``None`` if nothing recognised is found.
    """
    offset = 0
    while offset < len(payload):
        tag, offset = _read_varint(payload, offset)
        number, wire = divmod(tag, 8)
        if (number, wire) == (1, 2):
            raw, _ = _read_length_delimited(payload, offset)
            return raw.decode("utf-8")
        if (number, wire) == (2, 1):
            (number_value,) = struct.unpack("<d", payload[offset : offset + 8])
            return number_value
        if (number, wire) == (3, 0):
            flag, _ = _read_varint(payload, offset)
            return bool(flag)
        offset = _skip_unknown(payload, offset, wire)
    return None
1154
+
1155
+
1156
class DeltaCompressionManager:
    """Client-side bookkeeping for server-driven delta compression.

    Per channel it remembers the last full (or reconstructed) message, which
    is the base the next delta is applied to, and it aggregates
    ``DeltaStats`` that are reported through ``options.on_stats``.
    """

    def __init__(
        self,
        options: DeltaOptions,
        send_event: Callable[[str, Any, Optional[str]], Awaitable[bool]],
        prefix: ProtocolPrefix,
    ) -> None:
        self._options = options
        self._send_event = send_event
        self._prefix = prefix
        # Set by handle_enabled once the server acknowledges enable().
        self._enabled = False
        self._default_algorithm = DeltaAlgorithm.FOSSIL
        self._stats = DeltaStats()
        # channel -> {"conflation_key", "states", "base_message"}
        self._channel_states: Dict[str, Dict[str, Any]] = {}

    async def enable(self) -> None:
        """Ask the server to enable delta compression with the configured algorithms.

        Does nothing once the server has confirmed (see ``handle_enabled``).
        """
        if self._enabled:
            return
        await self._send_event(
            self._prefix.event("enable_delta_compression"),
            {"algorithms": [algorithm.value for algorithm in self._options.algorithms]},
            None,
        )

    def handle_enabled(self, data: Any) -> None:
        """Record the server's acknowledgement of ``enable``.

        ``enabled`` defaults to true when absent and is stored as a bool. An
        ``algorithm`` that is not a known ``DeltaAlgorithm`` is ignored and
        the current default is kept.
        """
        payload = data if isinstance(data, dict) else {}
        self._enabled = bool(payload.get("enabled", True))
        if "algorithm" in payload:
            try:
                self._default_algorithm = DeltaAlgorithm(payload["algorithm"])
            except ValueError:
                pass  # deliberately keep the previous default for unknown values

    def handle_cache_sync(self, channel: str, data: Any) -> None:
        """Reset *channel*'s delta state from a cache-sync payload.

        The stored base message is cleared, so the next delta for the
        channel triggers a resync until a full message arrives.
        """
        payload = data if isinstance(data, dict) else {}
        self._channel_states[channel] = {
            "conflation_key": payload.get("conflation_key"),
            "states": payload.get("states", {}),
            "base_message": None,
        }

    async def handle_delta_message(
        self, channel: str, data: Any
    ) -> Optional[SockudoEvent]:
        """Reconstruct a full event by applying a delta to the channel's base.

        Returns ``None`` when the payload lacks a string ``event`` or
        ``delta``. If *channel* has no base message, a ``delta_sync_error``
        is sent, the channel's state is dropped and ``None`` is returned.
        Decode, apply or parse failures increment ``stats.errors``, are
        passed to ``options.on_error``, and yield ``None``. On success the
        reconstructed message becomes the new base and the event is returned.
        """
        payload = data if isinstance(data, dict) else {}
        event_name = payload.get("event")
        delta_payload = payload.get("delta")
        if not isinstance(event_name, str) or not isinstance(delta_payload, str):
            return None
        algorithm = payload.get("algorithm", self._default_algorithm.value)
        sequence = _coerce_int(payload.get("seq"))
        conflation_key = payload.get("conflation_key")
        base_state = self._channel_states.get(channel)
        base_message = None if base_state is None else base_state.get("base_message")
        if base_message is None:
            await self._send_event(
                self._prefix.event("delta_sync_error"), {"channel": channel}, None
            )
            self._channel_states.pop(channel, None)
            return None
        try:
            delta_bytes = base64.b64decode(delta_payload)
            base_bytes = base_message.encode("utf-8")
            if algorithm == DeltaAlgorithm.XDELTA3.value:
                reconstructed_bytes = vcdiff_decoder.decode(base_bytes, delta_bytes)
            else:
                reconstructed_bytes = FossilDelta.apply(base_bytes, delta_bytes)
            reconstructed = reconstructed_bytes.decode("utf-8")
            parsed = json.loads(reconstructed)
            event_data = (
                parsed.get("data")
                if isinstance(parsed, dict) and "data" in parsed
                else parsed
            )
            # Store the new base directly: going through handle_full_message
            # would also count this delta as a full message and double totals.
            self._store_base_message(channel, reconstructed)
            self._stats.delta_messages += 1
            self._stats.total_messages += 1
            self._stats.total_bytes_without_compression += len(reconstructed_bytes)
            self._stats.total_bytes_with_compression += len(delta_bytes)
            self._report_stats()
            return SockudoEvent(
                event=event_name,
                channel=channel,
                data=event_data,
                user_id=None,
                message_id=None,
                raw_message=reconstructed,
                sequence=sequence,
                conflation_key=conflation_key,
            )
        except Exception as exc:  # not BaseException: let cancellation/exit propagate
            self._stats.errors += 1
            if self._options.on_error:
                self._options.on_error(exc)
            return None

    def handle_full_message(
        self,
        channel: str,
        raw_message: str,
        sequence: Optional[int],
        conflation_key: Optional[str],
    ) -> None:
        """Record an uncompressed message as *channel*'s new delta base.

        Byte counters use the UTF-8 size of *raw_message*. *sequence* and
        *conflation_key* are accepted for interface compatibility but are
        not used here.
        """
        self._store_base_message(channel, raw_message)
        size = self._utf8_size(raw_message)
        self._stats.full_messages += 1
        self._stats.total_messages += 1
        self._stats.total_bytes_without_compression += size
        self._stats.total_bytes_with_compression += size
        self._report_stats()

    def get_stats(self) -> DeltaStats:
        """Return the live stats object (not a copy)."""
        return self._stats

    def reset_stats(self) -> None:
        """Replace the stats with a fresh ``DeltaStats``."""
        self._stats = DeltaStats()

    def clear_channel_state(self, channel: str) -> None:
        """Forget all delta state for *channel* (no-op if unknown)."""
        self._channel_states.pop(channel, None)

    def _store_base_message(self, channel: str, raw_message: str) -> None:
        """Set the base message for *channel* without touching stats."""
        self._channel_states.setdefault(channel, {})["base_message"] = raw_message

    def _report_stats(self) -> None:
        """Invoke ``options.on_stats`` with the current stats, if configured."""
        if self._options.on_stats:
            self._options.on_stats(self._stats)

    @staticmethod
    def _utf8_size(text: str) -> int:
        """Return the UTF-8 byte length of *text*; lone surrogates never raise."""
        return len(text.encode("utf-8", "surrogatepass"))
1278
+
1279
+
1280
class PresenceMembers:
    """Local roster of a presence channel's members.

    Maps user id -> user info. ``count`` is seeded from the server's
    subscription payload and then adjusted by ``add``/``remove``; ``me`` is
    this client's own entry, looked up via ``my_id`` on subscription.
    """

    def __init__(self) -> None:
        self._members: Dict[str, Any] = {}
        self.count: int = 0
        self.my_id: Optional[str] = None
        self.me: Optional[PresenceMember] = None

    def member(self, member_id: str) -> Optional[PresenceMember]:
        """Return the member with *member_id*, or ``None`` if absent."""
        if member_id not in self._members:
            return None
        return PresenceMember(member_id, self._members[member_id])

    def remember_my_id(self, member_id: str) -> None:
        """Record this client's own user id (resolved into ``me`` on subscription)."""
        self.my_id = member_id

    def apply_subscription_data(self, data: Dict[str, Any]) -> None:
        """Replace the roster from a ``subscription_succeeded`` payload.

        Reads ``data["presence"]["hash"]`` (user id -> info) and
        ``data["presence"]["count"]``. A missing, ``null`` or malformed hash
        yields an empty roster; a missing, ``null`` or non-numeric count
        falls back to the roster size instead of raising.
        """
        presence = data.get("presence")
        if not isinstance(presence, dict):
            presence = {}
        try:
            members = dict(presence.get("hash") or {})
        except (TypeError, ValueError):
            members = {}
        self._members = members
        raw_count = presence.get("count")
        try:
            self.count = int(raw_count) if raw_count is not None else len(members)
        except (TypeError, ValueError):
            self.count = len(members)
        self.me = self.member(self.my_id) if self.my_id else None

    def add(self, data: Dict[str, Any]) -> Optional[PresenceMember]:
        """Add or update a member from a ``member_added`` payload.

        ``count`` only grows for user ids not already present. Returns the
        member, or ``None`` when ``user_id`` is not a string.
        """
        user_id = data.get("user_id")
        if not isinstance(user_id, str):
            return None
        if user_id not in self._members:
            self.count += 1
        self._members[user_id] = data.get("user_info")
        return PresenceMember(user_id, self._members[user_id])

    def remove(self, data: Dict[str, Any]) -> Optional[PresenceMember]:
        """Remove a member from a ``member_removed`` payload.

        Returns the removed member, or ``None`` if the id is not a string
        or is not in the roster. ``count`` never goes below zero.
        """
        user_id = data.get("user_id")
        if not isinstance(user_id, str) or user_id not in self._members:
            return None
        info = self._members.pop(user_id)
        self.count = max(0, self.count - 1)
        return PresenceMember(user_id, info)

    def reset(self) -> None:
        """Clear the roster, count and this client's identity."""
        self._members.clear()
        self.count = 0
        self.my_id = None
        self.me = None
1328
+
1329
+
1330
class SockudoChannel:
    """A subscribable channel: event bindings, subscription state, dispatch.

    Public channels need no signature (``authorize`` returns an empty
    auth); subclasses override ``authorize`` and ``handle`` for private,
    presence and encrypted channels.
    """

    def __init__(self, name: str, client: "SockudoClient") -> None:
        self.name = name
        self.client = client
        self.dispatcher = EventDispatcher()
        self.is_subscribed = False
        self.subscription_pending = False
        self.subscription_cancelled = False
        self.subscription_count: Optional[int] = None
        self.filter: Optional[FilterNode] = None
        self.delta_settings: Optional[ChannelDeltaSettings] = None
        self.events_filter: Optional[List[str]] = None
        self.rewind: Optional[SubscriptionRewind] = None
        # Strong references to fire-and-forget tasks (see _spawn).
        self._background_tasks: set = set()

    def bind(
        self, event_name: str, callback: Callable[[Any, Optional[EventMetadata]], None]
    ) -> str:
        """Bind *callback* to *event_name*; returns a token for ``unbind``."""
        return self.dispatcher.bind(event_name, callback)

    def bind_global(self, callback: Callable[[str, Any], None]) -> str:
        """Bind *callback* to every event on this channel."""
        return self.dispatcher.bind_global(callback)

    def unbind(
        self, event_name: Optional[str] = None, token: Optional[str] = None
    ) -> None:
        """Remove bindings matching *event_name* and/or *token*."""
        self.dispatcher.unbind(event_name, token)

    async def trigger(self, event: str, data: Any) -> bool:
        """Send a client event; the name must start with ``client-``.

        Raises:
            BadEventName: if *event* lacks the ``client-`` prefix.
        """
        if not event.startswith("client-"):
            raise BadEventName(f"Event '{event}' does not start with 'client-'")
        return await self.client.send_event(event, data, self.name)

    async def authorize(self, socket_id: str) -> ChannelAuthorizationData:
        """Public channels need no signature: return an empty auth."""
        return ChannelAuthorizationData(auth="")

    def _spawn(self, coro: Any) -> None:
        """Schedule *coro* as a task and hold a reference until it finishes.

        The event loop keeps only weak references to tasks, so an
        unreferenced fire-and-forget task can be garbage-collected mid-flight.
        """
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def subscribe_if_possible(self) -> None:
        """Start subscribing if connected, or revoke a pending cancellation."""
        if self.subscription_pending and self.subscription_cancelled:
            # A subscribe is already in flight; just undo the cancel request.
            self.subscription_cancelled = False
        elif (
            not self.subscription_pending
            and self.client.connection_state is ConnectionState.CONNECTED
        ):
            self._spawn(self.subscribe())

    async def subscribe(self) -> None:
        """Authorize and send the subscribe request.

        ``subscription_pending`` stays true until ``subscription_succeeded``
        arrives. Any ordinary failure clears the pending flag and emits a
        ``subscription_error`` event; cancellation clears the flag and is
        re-raised rather than swallowed.
        """
        if self.is_subscribed:
            return
        self.subscription_pending = True
        self.subscription_cancelled = False
        try:
            auth = await self.authorize(self.client.socket_id or "")
            payload: Dict[str, Any] = {"auth": auth.auth, "channel": self.name}
            if auth.channel_data is not None:
                payload["channel_data"] = auth.channel_data
            if self.filter is not None:
                payload["tags_filter"] = self.filter.to_dict()
            if self.delta_settings is not None:
                payload["delta"] = self.delta_settings.subscription_value()
            if self.events_filter is not None:
                payload["events"] = self.events_filter
            if self.rewind is not None:
                payload["rewind"] = self.rewind.subscription_value()
            await self.client.send_event(
                self.client.prefix.event("subscribe"), payload, None
            )
        except Exception as exc:
            self.subscription_pending = False
            self.dispatcher.emit(
                self.client.prefix.event("subscription_error"),
                {"type": "AuthError", "error": str(exc)},
            )
        except BaseException:
            # Cancellation / interpreter exit: reset state but let it propagate.
            self.subscription_pending = False
            raise

    async def unsubscribe(self) -> None:
        """Mark the channel unsubscribed and notify the server."""
        self.is_subscribed = False
        await self.client.send_event(
            self.client.prefix.event("unsubscribe"), {"channel": self.name}, None
        )

    def disconnect(self) -> None:
        """Reset subscription flags after the connection drops."""
        self.is_subscribed = False
        self.subscription_pending = False

    def handle(self, event: SockudoEvent) -> None:
        """Route an incoming event to internal handling or user bindings."""
        p = self.client.prefix
        if event.event == p.internal("subscription_succeeded"):
            self.subscription_pending = False
            self.is_subscribed = True
            if self.subscription_cancelled:
                # unsubscribe was requested while the subscribe was in flight.
                self._spawn(self.client.unsubscribe(self.name))
            else:
                self.dispatcher.emit(p.event("subscription_succeeded"), event.data)
        elif event.event == p.internal("subscription_count"):
            if isinstance(event.data, dict):
                self.subscription_count = _coerce_int(
                    event.data.get("subscription_count")
                )
            self.dispatcher.emit(p.event("subscription_count"), event.data)
        elif not p.is_internal_event(event.event):
            self.dispatcher.emit(
                event.event, event.data, EventMetadata(user_id=event.user_id)
            )
1431
+
1432
+
1433
class PrivateChannel(SockudoChannel):
    """Channel whose subscription must be signed via channel authorization."""

    async def authorize(self, socket_id: str) -> ChannelAuthorizationData:
        """Delegate signing to the client's configured channel authorizer."""
        auth_request = ChannelAuthorizationRequest(socket_id, self.name)
        return await self.client.config.authorize_channel(auth_request)
1438
+
1439
+
1440
class PresenceChannel(PrivateChannel):
    """Private channel that also maintains the roster of present members."""

    def __init__(self, name: str, client: "SockudoClient") -> None:
        super().__init__(name, client)
        self.members = PresenceMembers()

    async def authorize(self, socket_id: str) -> ChannelAuthorizationData:
        """Authorize the subscription and remember this client's user id.

        The id comes from ``channel_data.user_id`` when the auth response
        carries one, otherwise from the signed-in user (``client.user``).

        Raises:
            AuthFailure: when neither source yields a user id.
        """
        response = await super().authorize(socket_id)
        if response.channel_data:
            parsed = json.loads(response.channel_data)
            if isinstance(parsed, dict) and isinstance(parsed.get("user_id"), str):
                self.members.remember_my_id(parsed["user_id"])
                return response
        if self.client.user.user_id:
            self.members.remember_my_id(self.client.user.user_id)
            return response
        raise AuthFailure(
            None, f"Invalid auth response for presence channel '{self.name}'"
        )

    def handle(self, event: SockudoEvent) -> None:
        """Handle presence lifecycle events; defer everything else to the base."""
        p = self.client.prefix
        if event.event == p.internal("subscription_succeeded"):
            if self.subscription_cancelled:
                # unsubscribe was requested while pending: the base class marks
                # the channel subscribed and immediately unsubscribes, without
                # emitting subscription_succeeded or loading the roster.
                super().handle(event)
                return
            self.subscription_pending = False
            self.is_subscribed = True
            payload = event.data if isinstance(event.data, dict) else {}
            self.members.apply_subscription_data(payload)
            self.dispatcher.emit(p.event("subscription_succeeded"), self.members)
        elif event.event == p.internal("member_added") and isinstance(event.data, dict):
            member = self.members.add(event.data)
            if member is not None:
                self.dispatcher.emit(p.event("member_added"), member)
        elif event.event == p.internal("member_removed") and isinstance(
            event.data, dict
        ):
            member = self.members.remove(event.data)
            if member is not None:
                self.dispatcher.emit(p.event("member_removed"), member)
        else:
            super().handle(event)

    def disconnect(self) -> None:
        """Clear the roster and reset subscription flags."""
        self.members.reset()
        super().disconnect()

    async def history(
        self, params: Optional[PresenceHistoryParams] = None
    ) -> PresenceHistoryPage:
        """Fetch a page of presence history via the configured endpoint."""
        return await self.client.config.fetch_presence_history(
            self.name, params or PresenceHistoryParams()
        )

    async def snapshot(
        self, params: Optional[PresenceSnapshotParams] = None
    ) -> PresenceSnapshot:
        """Fetch a presence snapshot via the configured endpoint."""
        return await self.client.config.fetch_presence_snapshot(
            self.name, params or PresenceSnapshotParams()
        )
1497
+
1498
+
1499
class EncryptedChannel(PrivateChannel):
    """Private channel whose event payloads are NaCl ``SecretBox``-encrypted.

    Internal and platform events bypass decryption. User events are
    decrypted with the shared secret obtained during authorization;
    decryption or JSON errors are not caught here.
    """

    def __init__(self, name: str, client: "SockudoClient") -> None:
        super().__init__(name, client)
        self.shared_secret: Optional[bytes] = None

    async def authorize(self, socket_id: str) -> ChannelAuthorizationData:
        """Authorize and capture the channel's shared secret.

        The base64 ``shared_secret`` from the auth response is decoded and
        stored on the channel; the returned data omits it.

        Raises:
            AuthFailure: if the auth response carries no shared_secret.
        """
        auth_data = await super().authorize(socket_id)
        encoded_secret = auth_data.shared_secret
        if not encoded_secret:
            raise AuthFailure(
                None,
                f"No shared_secret key in auth payload for encrypted channel: {self.name}",
            )
        self.shared_secret = base64.b64decode(encoded_secret)
        return ChannelAuthorizationData(
            auth=auth_data.auth, channel_data=auth_data.channel_data
        )

    async def trigger(self, event: str, data: Any) -> bool:
        """Client events are not supported on encrypted channels.

        Raises:
            UnsupportedFeature: always.
        """
        raise UnsupportedFeature(
            "Client events are not currently supported for encrypted channels"
        )

    def handle(self, event: SockudoEvent) -> None:
        """Decrypt user events and dispatch the decoded JSON payload.

        Events are silently ignored when no secret is known yet or the
        payload lacks string ``ciphertext``/``nonce`` fields.
        """
        prefix = self.client.prefix
        name = event.event
        if prefix.is_internal_event(name) or prefix.is_platform_event(name):
            super().handle(event)
            return
        body = event.data
        if self.shared_secret is None or not isinstance(body, dict):
            return
        ciphertext_b64 = body.get("ciphertext")
        nonce_b64 = body.get("nonce")
        if not (isinstance(ciphertext_b64, str) and isinstance(nonce_b64, str)):
            return
        plaintext = SecretBox(self.shared_secret).decrypt(
            base64.b64decode(ciphertext_b64), base64.b64decode(nonce_b64)
        )
        self.dispatcher.emit(
            name,
            json.loads(plaintext.decode("utf-8")),
            EventMetadata(user_id=event.user_id),
        )
1539
+
1540
+
1541
class _ResolvedConfiguration:
    """Settings resolved from ``SockudoOptions`` plus the shared HTTP client.

    Besides host/port/timeout values, this object performs the HTTP calls
    for channel authorization, user authentication and presence history.
    """

    def __init__(self, options: SockudoOptions) -> None:
        self.cluster = options.cluster
        self.activity_timeout = options.activity_timeout
        # TLS is used unless force_tls is explicitly False.
        self.use_tls = options.force_tls is not False
        self.ws_host = options.ws_host or f"ws-{options.cluster}.sockudo.io"
        self.ws_port = options.ws_port
        self.wss_port = options.wss_port
        self.ws_path = options.ws_path
        self.http_host = options.http_host or f"sockjs-{options.cluster}.sockudo.io"
        self.http_port = options.http_port
        self.https_port = options.https_port
        self.http_path = options.http_path
        self.pong_timeout = options.pong_timeout
        self.unavailable_timeout = options.unavailable_timeout
        self.enabled_transports = options.enabled_transports
        self.disabled_transports = options.disabled_transports
        self.channel_options = options.channel_authorization
        self.user_options = options.user_authentication
        self.presence_history = options.presence_history
        self._http_client = httpx.AsyncClient()

    async def authorize_channel(
        self, request: ChannelAuthorizationRequest
    ) -> ChannelAuthorizationData:
        """Obtain a channel subscription signature.

        Uses ``channel_authorization.custom_handler`` when set; otherwise
        POSTs the configured params plus ``socket_id``/``channel_name`` to
        the auth endpoint.

        Raises:
            AuthFailure: on an HTTP error status or a malformed response.
        """
        if self.channel_options.custom_handler is not None:
            return await self.channel_options.custom_handler(request)
        params = self._merged(
            self.channel_options.params, self.channel_options.params_provider
        )
        params["socket_id"] = request.socket_id
        params["channel_name"] = request.channel_name
        headers = self._merged(
            self.channel_options.headers, self.channel_options.headers_provider
        )
        payload = await self._perform_auth_request(
            self.channel_options.endpoint, headers, params
        )
        auth = payload.get("auth")
        if not isinstance(auth, str):
            raise AuthFailure(200, "JSON returned from auth endpoint was invalid")
        return ChannelAuthorizationData(
            auth=auth,
            channel_data=payload.get("channel_data"),
            shared_secret=payload.get("shared_secret"),
        )

    async def authenticate_user(
        self, request: UserAuthenticationRequest
    ) -> UserAuthenticationData:
        """Obtain a user-authentication signature.

        Uses ``user_authentication.custom_handler`` when set; otherwise
        POSTs the configured params plus ``socket_id`` to the endpoint.

        Raises:
            AuthFailure: on an HTTP error status or a malformed response.
        """
        if self.user_options.custom_handler is not None:
            return await self.user_options.custom_handler(request)
        params = self._merged(self.user_options.params, self.user_options.params_provider)
        params["socket_id"] = request.socket_id
        headers = self._merged(
            self.user_options.headers, self.user_options.headers_provider
        )
        payload = await self._perform_auth_request(
            self.user_options.endpoint, headers, params
        )
        auth = payload.get("auth")
        user_data = payload.get("user_data")
        if not isinstance(auth, str) or not isinstance(user_data, str):
            raise AuthFailure(200, "JSON returned from auth endpoint was invalid")
        return UserAuthenticationData(auth=auth, user_data=user_data)

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._http_client.aclose()

    async def fetch_presence_history(
        self, channel_name: str, params: PresenceHistoryParams
    ) -> PresenceHistoryPage:
        """Fetch one page of presence history through the configured proxy.

        Raises:
            UnsupportedFeature: if no presence_history endpoint is configured.
            SockudoException: on HTTP errors or an invalid response body.
        """
        config = self.presence_history
        if config is None:
            raise UnsupportedFeature(
                "presence_history.endpoint must be configured to use presence.history(). "
                "This endpoint should proxy requests to the Sockudo server REST API."
            )

        payload = await self._perform_presence_history_request(
            config.endpoint,
            config.headers,
            config.headers_provider,
            channel_name,
            params.to_payload(),
            "history",
        )

        def fetch_next(cursor: str) -> Awaitable[PresenceHistoryPage]:
            # Same query as this page, continued from *cursor*.
            return self.fetch_presence_history(
                channel_name,
                PresenceHistoryParams(
                    direction=params.direction,
                    limit=params.limit,
                    cursor=cursor,
                    start_serial=params.start_serial,
                    end_serial=params.end_serial,
                    start_time_ms=params.start_time_ms,
                    end_time_ms=params.end_time_ms,
                    start=params.start,
                    end=params.end,
                ),
            )

        return self._decode_presence_history_page(payload, fetch_next)

    async def fetch_presence_snapshot(
        self, channel_name: str, params: PresenceSnapshotParams
    ) -> PresenceSnapshot:
        """Fetch a presence snapshot through the configured proxy.

        Raises:
            UnsupportedFeature: if no presence_history endpoint is configured.
            SockudoException: on HTTP errors or an invalid response body.
        """
        config = self.presence_history
        if config is None:
            raise UnsupportedFeature(
                "presence_history.endpoint must be configured to use presence.snapshot(). "
                "This endpoint should proxy requests to the Sockudo server REST API."
            )

        payload = await self._perform_presence_history_request(
            config.endpoint,
            config.headers,
            config.headers_provider,
            channel_name,
            params.to_payload(),
            "snapshot",
        )
        return self._decode_presence_snapshot(payload)

    async def _perform_auth_request(
        self, endpoint: str, headers: Dict[str, str], params: Dict[str, AuthValue]
    ) -> Dict[str, Any]:
        """POST *params* form-encoded to *endpoint*; return the JSON object.

        Booleans are sent as ``true``/``false``. A
        ``Content-Type: application/x-www-form-urlencoded`` header is added
        unless *headers* already sets one (in any letter case); without it,
        typical server-side form parsers ignore the body.

        Raises:
            AuthFailure: for status >= 400, a non-JSON body, or JSON that is
                not an object.
        """
        body = urllib.parse.urlencode(
            {
                key: str(value).lower() if isinstance(value, bool) else value
                for key, value in params.items()
            }
        )
        response = await self._http_client.post(
            endpoint,
            headers=self._with_content_type(
                headers, "application/x-www-form-urlencoded"
            ),
            content=body,
        )
        if response.status_code >= 400:
            raise AuthFailure(
                response.status_code,
                f"Could not get auth info from endpoint, status: {response.status_code}",
            )
        try:
            payload = response.json()
        except ValueError as exc:
            raise AuthFailure(
                response.status_code, "JSON returned from auth endpoint was invalid"
            ) from exc
        if not isinstance(payload, dict):
            raise AuthFailure(
                response.status_code, "JSON returned from auth endpoint was invalid"
            )
        return payload

    async def _perform_presence_history_request(
        self,
        endpoint: str,
        headers: Dict[str, str],
        headers_provider: Optional[PresenceHistoryHeadersProvider],
        channel_name: str,
        params: Dict[str, Any],
        action: str,
    ) -> Dict[str, Any]:
        """POST a JSON presence request and return the decoded JSON object.

        Raises:
            SockudoException: for status >= 400, a non-JSON body, or JSON
                that is not an object.
        """
        merged_headers = self._with_content_type(
            self._merged(headers, headers_provider), "application/json"
        )
        response = await self._http_client.post(
            endpoint,
            headers=merged_headers,
            content=json.dumps(
                {
                    "channel": channel_name,
                    "params": params,
                    "action": action,
                }
            ),
        )
        if response.status_code >= 400:
            raise SockudoException(
                f"Presence {action} request failed ({response.status_code}): "
                f"{response.text}"
            )
        try:
            payload = response.json()
        except ValueError as exc:
            raise SockudoException(
                f"Presence {action} endpoint returned invalid JSON"
            ) from exc
        if not isinstance(payload, dict):
            raise SockudoException(f"Presence {action} endpoint returned invalid JSON")
        return payload

    def _decode_presence_history_page(
        self,
        payload: Dict[str, Any],
        fetch_next: Callable[[str], Awaitable[PresenceHistoryPage]],
    ) -> PresenceHistoryPage:
        """Build a ``PresenceHistoryPage``; non-dict items are skipped.

        ``null`` items/limit/direction fall back to their defaults.
        """
        items = [
            PresenceHistoryItem(
                stream_id=str(item["stream_id"]),
                serial=int(item["serial"]),
                published_at_ms=int(item["published_at_ms"]),
                event=str(item["event"]),
                cause=str(item["cause"]),
                user_id=str(item["user_id"]),
                connection_id=self._optional_str(item, "connection_id"),
                dead_node_id=self._optional_str(item, "dead_node_id"),
                payload_size_bytes=int(item["payload_size_bytes"]),
                presence_event=dict(item.get("presence_event") or {}),
            )
            for item in payload.get("items") or []
            if isinstance(item, dict)
        ]
        return PresenceHistoryPage(
            items=items,
            direction=str(payload.get("direction") or "oldest_first"),
            limit=int(payload.get("limit") or 0),
            has_more=bool(payload.get("has_more", False)),
            next_cursor=self._optional_str(payload, "next_cursor"),
            bounds=self._decode_presence_history_bounds(payload.get("bounds")),
            continuity=self._decode_presence_history_continuity(
                payload.get("continuity")
            ),
            _fetch_next=fetch_next,
        )

    def _decode_presence_snapshot(self, payload: Dict[str, Any]) -> PresenceSnapshot:
        """Build a ``PresenceSnapshot``; non-dict members are skipped."""
        members = [
            PresenceSnapshotMember(
                user_id=str(member["user_id"]),
                last_event=str(member["last_event"]),
                last_event_serial=int(member["last_event_serial"]),
                last_event_at_ms=int(member["last_event_at_ms"]),
            )
            for member in payload.get("members") or []
            if isinstance(member, dict)
        ]
        return PresenceSnapshot(
            channel=str(payload.get("channel", "")),
            members=members,
            member_count=int(payload.get("member_count") or 0),
            events_replayed=int(payload.get("events_replayed") or 0),
            snapshot_serial=self._optional_int(payload, "snapshot_serial"),
            snapshot_time_ms=self._optional_int(payload, "snapshot_time_ms"),
            continuity=self._decode_presence_history_continuity(
                payload.get("continuity")
            ),
        )

    def _decode_presence_history_bounds(self, payload: Any) -> PresenceHistoryBounds:
        """Build ``PresenceHistoryBounds``; a non-dict payload means no bounds."""
        if not isinstance(payload, dict):
            payload = {}
        return PresenceHistoryBounds(
            start_serial=self._optional_int(payload, "start_serial"),
            end_serial=self._optional_int(payload, "end_serial"),
            start_time_ms=self._optional_int(payload, "start_time_ms"),
            end_time_ms=self._optional_int(payload, "end_time_ms"),
        )

    def _decode_presence_history_continuity(
        self, payload: Any
    ) -> PresenceHistoryContinuity:
        """Build ``PresenceHistoryContinuity``; a non-dict payload uses defaults."""
        if not isinstance(payload, dict):
            payload = {}
        return PresenceHistoryContinuity(
            stream_id=self._optional_str(payload, "stream_id"),
            oldest_available_serial=self._optional_int(
                payload, "oldest_available_serial"
            ),
            newest_available_serial=self._optional_int(
                payload, "newest_available_serial"
            ),
            oldest_available_published_at_ms=self._optional_int(
                payload, "oldest_available_published_at_ms"
            ),
            newest_available_published_at_ms=self._optional_int(
                payload, "newest_available_published_at_ms"
            ),
            retained_events=int(payload.get("retained_events") or 0),
            retained_bytes=int(payload.get("retained_bytes") or 0),
            degraded=bool(payload.get("degraded", False)),
            complete=bool(payload.get("complete", False)),
            truncated_by_retention=bool(payload.get("truncated_by_retention", False)),
        )

    @staticmethod
    def _merged(
        static: Dict[str, Any], provider: Optional[Callable[[], Dict[str, Any]]]
    ) -> Dict[str, Any]:
        """Copy *static*, then overlay the result of *provider* if one is given."""
        merged = dict(static)
        if provider:
            merged.update(provider())
        return merged

    @staticmethod
    def _with_content_type(headers: Dict[str, str], content_type: str) -> Dict[str, str]:
        """Copy *headers*, adding Content-Type unless present in any letter case."""
        merged = dict(headers)
        if not any(name.lower() == "content-type" for name in merged):
            merged["Content-Type"] = content_type
        return merged

    @staticmethod
    def _optional_int(payload: Dict[str, Any], key: str) -> Optional[int]:
        """Return ``int(payload[key])``, or ``None`` when missing or null."""
        value = payload.get(key)
        return int(value) if value is not None else None

    @staticmethod
    def _optional_str(payload: Dict[str, Any], key: str) -> Optional[str]:
        """Return ``str(payload[key])``, or ``None`` when missing or null."""
        value = payload.get(key)
        return str(value) if value is not None else None
1864
+
1865
+
1866
class SockudoClient:
    """Asynchronous Sockudo WebSocket client.

    Owns the active socket, the channel registry and the background tasks
    that drive activity pings, the UNAVAILABLE state and reconnection.
    Message deduplication, connection recovery and delta compression are
    switched on through ``options``.
    """

    def __init__(self, key: str, options: SockudoOptions) -> None:
        """Validate ``key``/``options`` and set up client state.

        The socket is not opened here; call :meth:`connect`.

        Raises:
            InvalidAppKey: if ``key`` is empty.
            InvalidOptions: if ``options.cluster`` is empty.
        """
        if not key:
            raise InvalidAppKey(
                "You must pass your app key when you instantiate SockudoClient."
            )
        if not options.cluster:
            raise InvalidOptions("Options must provide a cluster.")
        self.key = key
        self.options = options
        self.prefix = ProtocolPrefix(options.protocol_version)
        self.config = _ResolvedConfiguration(options)
        self.dispatcher = EventDispatcher()
        self.channels: Dict[str, SockudoChannel] = {}
        self.socket = None
        self.connection_state = ConnectionState.INITIALIZED
        self.socket_id: Optional[str] = None
        self._receive_task: Optional[asyncio.Task[Any]] = None
        self._activity_task: Optional[asyncio.Task[Any]] = None
        self._retry_task: Optional[asyncio.Task[Any]] = None
        self._unavailable_task: Optional[asyncio.Task[Any]] = None
        # Set by disconnect(); suppresses automatic reconnection.
        self._manually_disconnected = False
        self._current_transport: Optional[SockudoTransport] = None
        # Set when a retry switched from WS to WSS; cleared on the next
        # ordinary attempt.
        self._attempted_fallback = False
        # Latest serial seen per channel; replayed in a "resume" event after
        # reconnecting when connection recovery is enabled.
        self._channel_positions: Dict[str, RecoveryPosition] = {}
        self._deduplicator = (
            MessageDeduplicator(options.message_deduplication_capacity)
            if options.message_deduplication
            else None
        )
        self._delta_manager = (
            DeltaCompressionManager(
                options.delta_compression, self.send_event, self.prefix
            )
            if options.delta_compression
            else None
        )
        self.user = self.UserFacade(self)
        self.watchlist = self.WatchlistFacade()

    def bind(
        self, event_name: str, callback: Callable[[Any, Optional[EventMetadata]], None]
    ) -> str:
        """Register ``callback`` for ``event_name`` on the client dispatcher.

        Returns the handle produced by the dispatcher's ``bind``.
        """
        return self.dispatcher.bind(event_name, callback)

    def bind_global(self, callback: Callable[[str, Any], None]) -> str:
        """Register a global ``callback`` receiving ``(event_name, data)``."""
        return self.dispatcher.bind_global(callback)

    def channel(self, name: str) -> Optional[SockudoChannel]:
        """Return the registered channel called ``name``, or ``None``."""
        return self.channels.get(name)

    def subscribe(
        self, channel_name: str, options: Optional[SubscriptionOptions] = None
    ) -> SockudoChannel:
        """Get or create ``channel_name``, apply ``options`` and subscribe if possible.

        When ``options`` is given it overwrites the channel's filter, delta
        settings, event filter and rewind settings, even for an existing channel.
        """
        channel = self.channels.get(channel_name)
        if channel is None:
            channel = self._create_channel(channel_name)
            self.channels[channel_name] = channel
        if options is not None:
            channel.filter = options.filter
            channel.delta_settings = options.delta
            channel.events_filter = options.events
            channel.rewind = options.rewind
        channel.subscribe_if_possible()
        return channel

    async def unsubscribe(self, channel_name: str) -> None:
        """Unsubscribe from ``channel_name`` and drop its recovery/delta state.

        A subscription that is still pending is only flagged as cancelled and
        stays registered; otherwise the channel is removed from the registry.
        """
        channel = self.channels.get(channel_name)
        if channel is None:
            return
        if channel.subscription_pending:
            channel.subscription_cancelled = True
        elif channel.is_subscribed:
            self.channels.pop(channel_name, None)
            await channel.unsubscribe()
        else:
            self.channels.pop(channel_name, None)
        self._channel_positions.pop(channel_name, None)
        if self._delta_manager is not None:
            self._delta_manager.clear_channel_state(channel_name)

    async def connect(self) -> None:
        """Open the connection using the first allowed transport.

        Does nothing if a socket is already open. Moves to FAILED when every
        transport is disabled.
        """
        if self.socket is not None:
            return
        transports = self._transport_sequence()
        if not transports:
            self._update_state(ConnectionState.FAILED)
            return
        self._manually_disconnected = False
        self._attempted_fallback = False
        self._update_state(ConnectionState.CONNECTING)
        await self._open_websocket(transports[0])
        self._set_unavailable_timer()

    async def disconnect(self) -> None:
        """Close the socket on purpose, stop timers and release the config."""
        self._manually_disconnected = True
        self._cancel_timers()
        if self.socket is not None:
            await self.socket.close()
            self.socket = None
        for channel in self.channels.values():
            channel.disconnect()
        self._update_state(ConnectionState.DISCONNECTED)
        await self.config.close()

    async def close(self) -> None:
        """Alias for :meth:`disconnect`."""
        await self.disconnect()

    async def signin(self) -> None:
        """Request user sign-in (performed once the connection is established)."""
        await self.user.sign_in()

    async def send_event(self, name: str, data: Any, channel: Optional[str]) -> bool:
        """Encode and send one event frame.

        Returns:
            ``False`` when no socket is open, otherwise ``True`` after the
            frame has been handed to the socket.
        """
        if self.socket is None:
            return False
        payload: Dict[str, Any] = {"event": name, "data": data}
        if channel is not None:
            payload["channel"] = channel
        encoded = ProtocolCodec.encode_envelope(payload, self.options.wire_format)
        await self.socket.send(encoded)
        return True

    def get_delta_stats(self) -> Optional[DeltaStats]:
        """Return delta-compression statistics, or ``None`` when disabled."""
        return self._delta_manager.get_stats() if self._delta_manager else None

    async def _open_websocket(self, transport: SockudoTransport) -> None:
        """Connect over ``transport`` and start the background receive loop."""
        self._current_transport = transport
        self.socket = await ws_connect(self._socket_url(transport))
        self._receive_task = asyncio.create_task(self._receive_loop())

    async def _receive_loop(self) -> None:
        """Feed every frame of the current socket to :meth:`_handle_raw_message`.

        ``websockets`` ends ``async for`` silently on a clean close (1000/1001)
        and raises ``ConnectionClosed`` otherwise, so both exits are routed to
        :meth:`_handle_socket_closed`.
        """
        socket = self.socket
        if socket is None:
            return
        try:
            async for raw_message in socket:
                await self._handle_raw_message(raw_message)
        except ConnectionClosed as exc:
            await self._handle_socket_closed(exc.code, exc.reason)
            return
        # Clean close: ignore it when the user closed the socket on purpose or
        # a newer socket has already replaced this one.
        if self._manually_disconnected or self.socket is not socket:
            return
        code = socket.close_code
        await self._handle_socket_closed(
            code if code is not None else 1000, socket.close_reason or ""
        )

    async def _handle_raw_message(self, raw_message: Union[str, bytes]) -> None:
        """Decode one frame and route it to the client, channels and managers.

        Any ``Exception`` raised while handling the frame is emitted as an
        ``"error"`` event instead of ending the receive loop. Task cancellation
        still propagates.
        """
        try:
            event = ProtocolCodec.decode_event(raw_message, self.options.wire_format)
            if event.message_id and self._deduplicator:
                if self._deduplicator.is_duplicate(event.message_id):
                    return
                self._deduplicator.track(event.message_id)
            self._reset_activity_timer()
            if (
                self.options.connection_recovery
                and event.channel
                and event.serial is not None
            ):
                self._channel_positions[event.channel] = RecoveryPosition(
                    serial=event.serial,
                    stream_id=event.stream_id,
                    last_message_id=event.message_id,
                )
            event_name = event.event
            if event_name == self.prefix.event("connection_established"):
                await self._handle_connection_established(event.data)
            elif event_name == self.prefix.event("error"):
                self.dispatcher.emit("error", event.data)
            elif event_name == self.prefix.event("ping"):
                await self.send_event(self.prefix.event("pong"), {}, None)
            elif event_name == self.prefix.event("signin_success"):
                await self.user.handle_sign_in_success(event.data)
            elif event_name == self.prefix.event("resume_failed"):
                self._handle_resume_failed(event_name, event.data)
            elif event_name == self.prefix.internal("watchlist_events"):
                self.watchlist.handle(event.data)
            elif (
                event_name == self.prefix.event("delta_compression_enabled")
                and self._delta_manager
            ):
                self._delta_manager.handle_enabled(event.data)
                self.dispatcher.emit(event_name, event.data)
            elif (
                event_name == self.prefix.event("delta_cache_sync")
                and self._delta_manager
                and event.channel
            ):
                self._delta_manager.handle_cache_sync(event.channel, event.data)
            elif (
                event_name == self.prefix.event("delta")
                and self._delta_manager
                and event.channel
            ):
                reconstructed = await self._delta_manager.handle_delta_message(
                    event.channel, event.data
                )
                if reconstructed is not None:
                    channel = self.channels.get(event.channel)
                    if channel is not None:
                        channel.handle(reconstructed)
                    self.dispatcher.emit(reconstructed.event, reconstructed.data)
            else:
                if event.channel and event.channel in self.channels:
                    self.channels[event.channel].handle(event)
                if event.channel:
                    # The server-to-user channel is not kept in self.channels,
                    # so its frames must be forwarded explicitly.
                    self.user.handle_channel_event(event)
                if (
                    not self.prefix.is_platform_event(event_name)
                    and not self.prefix.is_internal_event(event_name)
                    and event.sequence is not None
                    and self._delta_manager is not None
                ):
                    self._delta_manager.handle_full_message(
                        event.channel,
                        self._strip_delta_metadata(event.raw_message),
                        event.sequence,
                        event.conflation_key,
                    )
                if not self.prefix.is_internal_event(event_name):
                    self.dispatcher.emit(
                        event_name, event.data, EventMetadata(user_id=event.user_id)
                    )
        except Exception as exc:
            self.dispatcher.emit("error", exc)

    async def _handle_connection_established(self, data: Any) -> None:
        """Finish the handshake: store the socket id, go CONNECTED, restore state.

        Raises:
            SockudoException: if the payload carries no string ``socket_id``.
        """
        payload = data if isinstance(data, dict) else {}
        self.socket_id = payload.get("socket_id")
        if not isinstance(self.socket_id, str):
            raise SockudoException("Invalid handshake")
        # The handshake succeeded; the pending timer must not later flip a
        # healthy connection to UNAVAILABLE.
        self._clear_unavailable_timer()
        self._update_state(ConnectionState.CONNECTED, {"socket_id": self.socket_id})
        for channel in self.channels.values():
            channel.subscribe_if_possible()
        if self.options.connection_recovery and self._channel_positions:
            await self.send_event(
                self.prefix.event("resume"),
                {
                    "channel_positions": {
                        channel_name: {
                            key: value
                            for key, value in {
                                "serial": position.serial,
                                "stream_id": position.stream_id,
                                "last_message_id": position.last_message_id,
                            }.items()
                            if value is not None
                        }
                        for channel_name, position in self._channel_positions.items()
                    }
                },
                None,
            )
        if (
            self.options.delta_compression
            and self.options.delta_compression.enabled
        ):
            assert self._delta_manager is not None
            await self._delta_manager.enable()
        await self.user.handle_connected()

    def _handle_resume_failed(self, event_name: str, data: Any) -> None:
        """Drop recovery/delta state for the failed channel and resubscribe it."""
        payload = data if isinstance(data, dict) else {}
        failed_channel_name = payload.get("channel")
        if isinstance(failed_channel_name, str):
            self._channel_positions.pop(failed_channel_name, None)
            if self._delta_manager is not None:
                self._delta_manager.clear_channel_state(failed_channel_name)
            failed_channel = self.channels.get(failed_channel_name)
            if failed_channel is not None:
                failed_channel.is_subscribed = False
                failed_channel.subscription_pending = False
                failed_channel.subscribe_if_possible()
        self.dispatcher.emit(event_name, data)

    async def _handle_socket_closed(self, code: int, reason: str) -> None:
        """Tear down per-socket state and schedule a reconnect unless manual."""
        self.socket = None
        self._cancel_activity_timer()
        self._clear_unavailable_timer()
        for channel in self.channels.values():
            channel.disconnect()
        if not self._manually_disconnected:
            await self._schedule_retry(1.0)
        if reason:
            self.dispatcher.emit("error", reason)

    async def _schedule_retry(self, after_seconds: float) -> None:
        """Schedule a reconnection attempt ``after_seconds`` from now.

        Replaces any pending retry. A plain-WS connection is retried once over
        WSS when WSS is allowed. If the attempt fails, the error is emitted
        and another retry is scheduled, so the client is never left in
        CONNECTING with nothing pending.
        """
        if self._manually_disconnected:
            return
        if self._retry_task:
            self._retry_task.cancel()

        async def _retry() -> None:
            await asyncio.sleep(after_seconds)
            self._update_state(ConnectionState.CONNECTING)
            transports = self._transport_sequence()
            if (
                self._current_transport is SockudoTransport.WS
                and not self._attempted_fallback
                and SockudoTransport.WSS in transports
            ):
                self._attempted_fallback = True
                transport = SockudoTransport.WSS
            else:
                self._attempted_fallback = False
                transport = transports[0] if transports else SockudoTransport.WSS
            try:
                await self._open_websocket(transport)
            except Exception as exc:
                self.dispatcher.emit("error", exc)
                # This task is about to finish; drop the reference so the
                # nested _schedule_retry does not cancel the task it runs in.
                self._retry_task = None
                await self._schedule_retry(after_seconds)
                return
            self._set_unavailable_timer()

        self._retry_task = asyncio.create_task(_retry())

    def _socket_url(self, transport: SockudoTransport) -> str:
        """Build the ``ws``/``wss`` URL, including protocol query parameters."""
        scheme = "wss" if transport is SockudoTransport.WSS else "ws"
        host = self.config.ws_host
        port = (
            self.config.wss_port
            if transport is SockudoTransport.WSS
            else self.config.ws_port
        )
        path = f"{self.config.ws_path}/app/{self.key}"
        query = {
            "protocol": self.prefix.version,
            "client": "python",
            "version": "2.0.0",
            "flash": "false",
        }
        if self.options.protocol_version >= 2:
            query["format"] = self.options.wire_format.value
            query["echo_messages"] = "true" if self.options.echo_messages else "false"
        return urllib.parse.urlunsplit(
            (scheme, f"{host}:{port}", path, urllib.parse.urlencode(query), "")
        )

    def _transport_sequence(self) -> List[SockudoTransport]:
        """Return allowed transports in try order (WSS only when TLS is forced)."""
        transports = (
            [SockudoTransport.WSS]
            if self.config.use_tls
            else [SockudoTransport.WS, SockudoTransport.WSS]
        )
        if self.config.enabled_transports is not None:
            transports = [
                transport
                for transport in transports
                if transport in self.config.enabled_transports
            ]
        if self.config.disabled_transports is not None:
            transports = [
                transport
                for transport in transports
                if transport not in self.config.disabled_transports
            ]
        return transports

    def _create_channel(self, name: str) -> SockudoChannel:
        """Instantiate the channel subclass implied by ``name``'s prefix."""
        if name.startswith("private-encrypted-"):
            return EncryptedChannel(name, self)
        if name.startswith("presence-"):
            return PresenceChannel(name, self)
        if name.startswith("private-"):
            return PrivateChannel(name, self)
        return SockudoChannel(name, self)

    def _update_state(
        self, state: ConnectionState, metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Set ``state`` and emit ``state_change`` plus the state's own event."""
        previous = self.connection_state
        self.connection_state = state
        self.dispatcher.emit(
            "state_change", {"previous": previous.value, "current": state.value}
        )
        self.dispatcher.emit(state.value, metadata)

    def _cancel_activity_timer(self) -> None:
        """Cancel the pending activity ping, if any."""
        if self._activity_task:
            self._activity_task.cancel()
            self._activity_task = None

    def _reset_activity_timer(self) -> None:
        """Restart the idle timer that sends a ping after ``activity_timeout``."""
        self._cancel_activity_timer()

        async def _timer() -> None:
            await asyncio.sleep(self.config.activity_timeout)
            await self.send_event(self.prefix.event("ping"), {}, None)

        self._activity_task = asyncio.create_task(_timer())

    def _set_unavailable_timer(self) -> None:
        """Move to UNAVAILABLE if no handshake arrives within ``unavailable_timeout``."""
        self._clear_unavailable_timer()

        async def _timer() -> None:
            await asyncio.sleep(self.config.unavailable_timeout)
            self._update_state(ConnectionState.UNAVAILABLE)

        self._unavailable_task = asyncio.create_task(_timer())

    def _clear_unavailable_timer(self) -> None:
        """Cancel the pending UNAVAILABLE transition, if any."""
        if self._unavailable_task:
            self._unavailable_task.cancel()
            self._unavailable_task = None

    def _cancel_timers(self) -> None:
        """Cancel the activity, unavailable and retry tasks."""
        self._cancel_activity_timer()
        self._clear_unavailable_timer()
        if self._retry_task:
            self._retry_task.cancel()
            self._retry_task = None

    @staticmethod
    def _strip_delta_metadata(raw_message: str) -> str:
        """Remove the ``__delta_seq`` and ``__conflation_key`` members from a JSON envelope.

        Whole members (key *and* value) are removed, whether they appear after
        another member (``,"key":value``) or as the first member
        (``{"key":value,``). Escaped occurrences inside string values such as
        ``data`` are left untouched.
        """
        import re

        cleaned = raw_message
        for key, value_pattern in (
            ("__delta_seq", r"-?\d+"),
            ("__conflation_key", r'"(?:[^"\\]|\\.)*"|null'),
        ):
            member = rf'"{key}"\s*:\s*(?:{value_pattern})'
            cleaned = re.sub(rf",\s*{member}", "", cleaned)
            cleaned = re.sub(rf"(?<=\{{)\s*{member}\s*,?", "", cleaned)
        return cleaned

    def reset_delta_stats(self) -> None:
        """Reset delta-compression statistics when delta compression is enabled."""
        if self._delta_manager is not None:
            self._delta_manager.reset_stats()

    class UserFacade:
        """User sign-in state plus the private server-to-user channel."""

        def __init__(self, client: "SockudoClient") -> None:
            self.client = client
            self.dispatcher = EventDispatcher()
            self.is_sign_in_requested = False
            self.user_data: Optional[Dict[str, Any]] = None
            self.server_channel: Optional[SockudoChannel] = None
            # Name of ``server_channel``; used to route incoming frames to it.
            self._server_channel_name: Optional[str] = None

        @property
        def user_id(self) -> Optional[str]:
            """The signed-in user's ``id`` string, or ``None``."""
            if self.user_data is None:
                return None
            value = self.user_data.get("id")
            return value if isinstance(value, str) else None

        def bind(
            self,
            event_name: str,
            callback: Callable[[Any, Optional[EventMetadata]], None],
        ) -> str:
            """Register ``callback`` for user events named ``event_name``."""
            return self.dispatcher.bind(event_name, callback)

        async def sign_in(self) -> None:
            """Request sign-in and attempt it now if the client is connected."""
            self.is_sign_in_requested = True
            await self._attempt_sign_in()

        async def handle_connected(self) -> None:
            """Re-attempt a requested sign-in after (re)connection."""
            await self._attempt_sign_in()

        def handle_channel_event(self, event: Any) -> None:
            """Forward ``event`` to the server-to-user channel when it targets it."""
            if (
                self.server_channel is not None
                and self._server_channel_name is not None
                and event.channel == self._server_channel_name
            ):
                self.server_channel.handle(event)

        async def handle_sign_in_success(self, data: Any) -> None:
            """Parse ``user_data`` and subscribe to the user's server channel.

            Invalid or malformed ``user_data`` clears the sign-in state.
            """
            payload = data if isinstance(data, dict) else {}
            user_data = payload.get("user_data")
            if not isinstance(user_data, str):
                self._cleanup()
                return
            try:
                parsed = json.loads(user_data)
            except ValueError:
                self._cleanup()
                return
            if not isinstance(parsed, dict) or not isinstance(parsed.get("id"), str):
                self._cleanup()
                return
            self.user_data = parsed
            await self._subscribe_server_channel(parsed["id"])

        async def _attempt_sign_in(self) -> None:
            """Authenticate and send the sign-in event when requested and connected."""
            if (
                not self.is_sign_in_requested
                or self.client.connection_state is not ConnectionState.CONNECTED
            ):
                return
            if not self.client.socket_id:
                return
            try:
                auth = await self.client.config.authenticate_user(
                    UserAuthenticationRequest(self.client.socket_id)
                )
                await self.client.send_event(
                    self.client.prefix.event("signin"),
                    {"auth": auth.auth, "user_data": auth.user_data},
                    None,
                )
            except Exception:
                # Best effort: a failed sign-in resets state; cancellation
                # still propagates.
                self._cleanup()

        async def _subscribe_server_channel(self, user_id: str) -> None:
            """Create and subscribe the ``#server-to-user-<id>`` channel."""
            # A re-sign-in (each connection re-attempts it) replaces any
            # earlier channel instead of leaking it.
            self._release_server_channel()
            name = f"#server-to-user-{user_id}"
            channel = SockudoChannel(name, self.client)
            channel.bind_global(self._forward_user_event)
            self.server_channel = channel
            self._server_channel_name = name
            channel.subscribe_if_possible()

        def _forward_user_event(self, event_name: str, data: Any) -> None:
            """Re-emit non-internal, non-platform channel events to user listeners."""
            prefix = self.client.prefix
            if prefix.is_internal_event(event_name) or prefix.is_platform_event(
                event_name
            ):
                return
            self.dispatcher.emit(event_name, data)

        def _release_server_channel(self) -> None:
            """Unbind and disconnect the current server channel, if any."""
            if self.server_channel is not None:
                self.server_channel.unbind()
                self.server_channel.disconnect()
            self.server_channel = None
            self._server_channel_name = None

        def _cleanup(self) -> None:
            """Forget the signed-in user and release the server channel."""
            self.user_data = None
            self._release_server_channel()

    class WatchlistFacade:
        """Dispatches watchlist events by their ``name`` field."""

        def __init__(self) -> None:
            self.dispatcher = EventDispatcher()

        def bind(
            self,
            event_name: str,
            callback: Callable[[Any, Optional[EventMetadata]], None],
        ) -> str:
            """Register ``callback`` for watchlist events named ``event_name``."""
            return self.dispatcher.bind(event_name, callback)

        def handle(self, data: Any) -> None:
            """Emit each dict in ``data["events"]`` that has a string ``name``."""
            payload = data if isinstance(data, dict) else {}
            events = payload.get("events", [])
            if not isinstance(events, list):
                return
            for event in events:
                if isinstance(event, dict) and isinstance(event.get("name"), str):
                    self.dispatcher.emit(event["name"], event)
+
2371
+
2372
+ def _coerce_int(value: Any) -> Optional[int]:
2373
+ if isinstance(value, bool):
2374
+ return None
2375
+ if isinstance(value, int):
2376
+ return value
2377
+ if isinstance(value, float):
2378
+ return int(value)
2379
+ return None
2380
+
2381
+
2382
+ def _write_varint(buffer: bytearray, value: int) -> None:
2383
+ while True:
2384
+ if value < 0x80:
2385
+ buffer.append(value)
2386
+ return
2387
+ buffer.append((value & 0x7F) | 0x80)
2388
+ value >>= 7
2389
+
2390
+
2391
def _write_key(buffer: bytearray, field: int, wire_type: int) -> None:
    """Append the protobuf tag (field number and wire type) to ``buffer``."""
    tag = field << 3
    tag |= wire_type
    _write_varint(buffer, tag)
+
2394
+
2395
def _write_string_field(buffer: bytearray, field: int, value: Any) -> None:
    """Append ``value`` as a UTF-8, length-delimited field; non-strings are skipped."""
    if isinstance(value, str):
        _write_bytes_field(buffer, field, value.encode("utf-8"))
+
2403
+
2404
def _write_bytes_field(buffer: bytearray, field: int, payload: bytes) -> None:
    """Append ``payload`` as a length-delimited (wire type 2) field."""
    _write_key(buffer, field, 2)
    _write_varint(buffer, len(payload))
    # bytearray += is an in-place extend of the caller's buffer.
    buffer += payload
+
2409
+
2410
def _write_uint_field(buffer: bytearray, field: int, value: Any) -> None:
    """Append ``value`` as a varint field when it coerces to an int; else skip."""
    number = _coerce_int(value)
    if number is not None:
        _write_key(buffer, field, 0)
        _write_varint(buffer, number)
+
2417
+
2418
def _write_optional_bool_field(
    buffer: bytearray, field: int, value: Optional[bool]
) -> None:
    """Append ``value`` as a varint boolean field unless it is ``None``."""
    if value is not None:
        _write_bool_field(buffer, field, value)
+
2426
+
2427
def _write_bool_field(buffer: bytearray, field: int, value: bool) -> None:
    """Append ``value`` as a varint field: 1 when truthy, 0 otherwise."""
    _write_key(buffer, field, 0)
    _write_varint(buffer, int(bool(value)))
+
2431
+
2432
def _write_double_field(buffer: bytearray, field: int, value: float) -> None:
    """Append ``value`` as a little-endian IEEE-754 double (wire type 1)."""
    _write_key(buffer, field, 1)
    buffer += struct.pack("<d", value)
+
2436
+
2437
+ def _read_varint(data: bytes, index: int) -> Tuple[int, int]:
2438
+ shift = 0
2439
+ result = 0
2440
+ while True:
2441
+ byte = data[index]
2442
+ index += 1
2443
+ result |= (byte & 0x7F) << shift
2444
+ if byte & 0x80 == 0:
2445
+ return result, index
2446
+ shift += 7
2447
+
2448
+
2449
def _read_length_delimited(data: bytes, index: int) -> Tuple[bytes, int]:
    """Read a varint length prefix at ``index`` and the payload after it.

    Returns:
        ``(payload, next_index)``.

    Raises:
        SockudoException: if fewer bytes remain than the prefix announces.
        Previously a short slice was returned silently, with an index past
        the end of ``data``.
    """
    length, start = _read_varint(data, index)
    end = start + length
    if end > len(data):
        raise SockudoException("Truncated protobuf length-delimited field")
    return data[start:end], end
+
2453
+
2454
+ def _skip_unknown(data: bytes, index: int, wire: int) -> int:
2455
+ if wire == 0:
2456
+ _, index = _read_varint(data, index)
2457
+ return index
2458
+ if wire == 1:
2459
+ return index + 8
2460
+ if wire == 2:
2461
+ payload, index = _read_length_delimited(data, index)
2462
+ return index
2463
+ if wire == 5:
2464
+ return index + 4
2465
+ raise SockudoException(f"Unsupported protobuf wire type: {wire}")