sockudo-python 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sockudo_python/__init__.py +47 -0
- sockudo_python/client.py +2465 -0
- sockudo_python-2.0.0.dist-info/METADATA +364 -0
- sockudo_python-2.0.0.dist-info/RECORD +7 -0
- sockudo_python-2.0.0.dist-info/WHEEL +5 -0
- sockudo_python-2.0.0.dist-info/licenses/LICENSE +21 -0
- sockudo_python-2.0.0.dist-info/top_level.txt +1 -0
sockudo_python/client.py
ADDED
|
@@ -0,0 +1,2465 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import base64
|
|
5
|
+
import json
|
|
6
|
+
import struct
|
|
7
|
+
import urllib.parse
|
|
8
|
+
from collections import OrderedDict
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from enum import Enum
|
|
11
|
+
from typing import (
|
|
12
|
+
Any,
|
|
13
|
+
Awaitable,
|
|
14
|
+
Callable,
|
|
15
|
+
Dict,
|
|
16
|
+
List,
|
|
17
|
+
Optional,
|
|
18
|
+
Tuple,
|
|
19
|
+
Union,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
import httpx
|
|
23
|
+
import msgpack
|
|
24
|
+
import vcdiff_decoder
|
|
25
|
+
from nacl.secret import SecretBox
|
|
26
|
+
from nacl.utils import random as nacl_random
|
|
27
|
+
from websockets.asyncio.client import connect as ws_connect
|
|
28
|
+
from websockets.exceptions import ConnectionClosed
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class SockudoException(RuntimeError):
    """Base class of every error defined by this client module."""


class InvalidAppKey(SockudoException):
    """Error category for an invalid application key."""


class InvalidOptions(SockudoException):
    """Error category for invalid client options."""


class UnsupportedFeature(SockudoException):
    """Error category for a feature the client does not support."""


class BadEventName(SockudoException):
    """Error category for an unacceptable event name."""


class AuthFailure(SockudoException):
    """Authentication/authorization failure with an optional status code.

    Args:
        status_code: Status code associated with the failure (presumably the
            HTTP status of the auth request), or ``None`` when there is none.
        message: Human-readable description; becomes the exception text.
    """

    def __init__(self, status_code: Optional[int], message: str) -> None:
        self.status_code = status_code
        super().__init__(message)


class DeltaFailure(SockudoException):
    """Error raised while applying a delta (see ``FossilDelta``)."""
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ConnectionState(str, Enum):
    """Connection lifecycle states; each member is also its string value."""

    # Initial / in-progress states.
    INITIALIZED = "initialized"
    CONNECTING = "connecting"
    # Established.
    CONNECTED = "connected"
    # Terminal or degraded states.
    DISCONNECTED = "disconnected"
    UNAVAILABLE = "unavailable"
    FAILED = "failed"
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class SockudoTransport(str, Enum):
    """WebSocket transport schemes: plain (``ws``) or TLS (``wss``)."""

    WS = "ws"  # unencrypted
    WSS = "wss"  # TLS
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class SockudoWireFormat(str, Enum):
    """Serialisation format used for frames on the socket."""

    JSON = "json"
    MESSAGEPACK = "messagepack"
    PROTOBUF = "protobuf"

    @property
    def is_binary(self) -> bool:
        """``True`` for every format except JSON (i.e. binary frames)."""
        return self != SockudoWireFormat.JSON
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class DeltaAlgorithm(str, Enum):
    """Delta-compression algorithms that can be requested."""

    FOSSIL = "fossil"  # decoded locally by FossilDelta
    XDELTA3 = "xdelta3"  # VCDIFF-style deltas
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
class FilterNode:
    """One node of a subscription filter tree.

    A node is either *logical* (``op`` is ``"and"``/``"or"``/``"not"`` and
    ``nodes`` holds the children) or a *leaf* comparison (``key``, ``cmp`` and
    ``val`` or ``vals``). The ``Filter`` helpers build these nodes.
    """

    op: Optional[str] = None
    key: Optional[str] = None
    cmp: Optional[str] = None
    val: Optional[str] = None
    vals: Optional[List[str]] = None
    nodes: Optional[List["FilterNode"]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialise this node and its children into plain dicts.

        Attributes that are ``None`` are left out, and an empty ``nodes`` list
        is treated the same as ``None`` (omitted).
        """
        children = None
        if self.nodes:
            children = [child.to_dict() for child in self.nodes]
        candidates = (
            ("op", self.op),
            ("key", self.key),
            ("cmp", self.cmp),
            ("val", self.val),
            ("vals", self.vals),
            ("nodes", children),
        )
        result: Dict[str, Any] = {}
        for name, value in candidates:
            if value is not None:
                result[name] = value
        return result
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class Filter:
    """Convenience constructors for ``FilterNode`` trees.

    Each leaf helper returns ``FilterNode(key=..., cmp=<code>, ...)`` with the
    comparison code named in its docstring; ``and_``, ``or_`` and ``not_``
    wrap child nodes in a logical node.
    """

    @staticmethod
    def _leaf(key: str, cmp: str, **operand: Any) -> FilterNode:
        # Single construction point for leaves; ``operand`` is ``val=...``,
        # ``vals=...`` or empty (existence checks take no operand).
        return FilterNode(key=key, cmp=cmp, **operand)

    @staticmethod
    def eq(key: str, value: str) -> FilterNode:
        """Leaf with comparison code ``"eq"`` against ``value``."""
        return Filter._leaf(key, "eq", val=value)

    @staticmethod
    def neq(key: str, value: str) -> FilterNode:
        """Leaf with comparison code ``"neq"`` against ``value``."""
        return Filter._leaf(key, "neq", val=value)

    @staticmethod
    def inside(key: str, values: List[str]) -> FilterNode:
        """Leaf with comparison code ``"in"`` against the list ``values``."""
        return Filter._leaf(key, "in", vals=values)

    @staticmethod
    def not_in(key: str, values: List[str]) -> FilterNode:
        """Leaf with comparison code ``"nin"`` against the list ``values``."""
        return Filter._leaf(key, "nin", vals=values)

    @staticmethod
    def exists(key: str) -> FilterNode:
        """Leaf with comparison code ``"ex"`` (no operand)."""
        return Filter._leaf(key, "ex")

    @staticmethod
    def not_exists(key: str) -> FilterNode:
        """Leaf with comparison code ``"nex"`` (no operand)."""
        return Filter._leaf(key, "nex")

    @staticmethod
    def starts_with(key: str, value: str) -> FilterNode:
        """Leaf with comparison code ``"sw"`` against ``value``."""
        return Filter._leaf(key, "sw", val=value)

    @staticmethod
    def ends_with(key: str, value: str) -> FilterNode:
        """Leaf with comparison code ``"ew"`` against ``value``."""
        return Filter._leaf(key, "ew", val=value)

    @staticmethod
    def contains(key: str, value: str) -> FilterNode:
        """Leaf with comparison code ``"ct"`` against ``value``."""
        return Filter._leaf(key, "ct", val=value)

    @staticmethod
    def gt(key: str, value: str) -> FilterNode:
        """Leaf with comparison code ``"gt"`` against ``value``."""
        return Filter._leaf(key, "gt", val=value)

    @staticmethod
    def gte(key: str, value: str) -> FilterNode:
        """Leaf with comparison code ``"gte"`` against ``value``."""
        return Filter._leaf(key, "gte", val=value)

    @staticmethod
    def lt(key: str, value: str) -> FilterNode:
        """Leaf with comparison code ``"lt"`` against ``value``."""
        return Filter._leaf(key, "lt", val=value)

    @staticmethod
    def lte(key: str, value: str) -> FilterNode:
        """Leaf with comparison code ``"lte"`` against ``value``."""
        return Filter._leaf(key, "lte", val=value)

    @staticmethod
    def and_(*nodes: FilterNode) -> FilterNode:
        """Logical ``"and"`` node over ``nodes``."""
        return FilterNode(op="and", nodes=[*nodes])

    @staticmethod
    def or_(*nodes: FilterNode) -> FilterNode:
        """Logical ``"or"`` node over ``nodes``."""
        return FilterNode(op="or", nodes=[*nodes])

    @staticmethod
    def not_(node: FilterNode) -> FilterNode:
        """Logical ``"not"`` node wrapping the single child ``node``."""
        return FilterNode(op="not", nodes=[node])
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
# Operator vocabularies accepted by validate_filter (match the Filter helpers).
_FILTER_LOGICAL_OPS = frozenset({"and", "or", "not"})
_FILTER_COMPARISON_OPS = frozenset(
    {"eq", "neq", "in", "nin", "ex", "nex", "sw", "ew", "ct", "gt", "gte", "lt", "lte"}
)


def validate_filter(filter_node: FilterNode) -> Optional[str]:
    """Check a filter tree for structural errors.

    Logical nodes must use a known operator and carry children (``"not"``
    exactly one, ``"and"``/``"or"`` at least one); every child is validated
    recursively. Leaf nodes need a key, a known comparison code, a non-empty
    ``vals`` list for ``in``/``nin``, and a ``val`` for every comparison
    other than ``ex``/``nex``.

    Args:
        filter_node: Root of the tree to validate.

    Returns:
        ``None`` when the tree is valid, otherwise a message describing the
        first problem found. Errors inside children are prefixed with the
        child's index, e.g. ``"Child node 1: ..."``.
    """
    op = filter_node.op
    if op:
        if op not in _FILTER_LOGICAL_OPS:
            return f"Invalid logical operator: {op}"
        children = filter_node.nodes
        if children is None:
            return f"Logical operation '{op}' requires nodes array"
        if op == "not" and len(children) != 1:
            return f"NOT operation requires exactly one child node, got {len(children)}"
        if op in ("and", "or") and not children:
            return f"{op.upper()} operation requires at least one child node"
        for index, child in enumerate(children):
            error = validate_filter(child)
            if error is not None:
                return f"Child node {index}: {error}"
        return None

    if not filter_node.key:
        return "Leaf node requires a key"
    cmp = filter_node.cmp
    if not cmp:
        return "Leaf node requires a comparison operator"
    if cmp not in _FILTER_COMPARISON_OPS:
        return f"Invalid comparison operator: {cmp}"
    if cmp in ("in", "nin"):
        if not filter_node.vals:
            return f"{cmp} operation requires non-empty vals array"
    elif cmp not in ("ex", "nex") and filter_node.val is None:
        # Only a *missing* operand is an error: "" is a legitimate value to
        # compare against (e.g. Filter.eq("status", "")), so do not use
        # truthiness here.
        return f"{cmp} operation requires a val"
    return None
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
@dataclass
class ChannelDeltaSettings:
    """Per-channel delta-compression preferences for a subscription."""

    enabled: Optional[bool] = None
    algorithm: Optional[DeltaAlgorithm] = None

    def subscription_value(self) -> Any:
        """Return the most compact representation of these settings.

        * only ``algorithm`` set -> the algorithm's string value
        * only ``enabled`` set (``True``/``False``) -> that boolean
        * otherwise -> a dict with whichever of ``enabled``/``algorithm`` are
          set (``{}`` when neither is)
        """
        if self.algorithm is not None and self.enabled is None:
            return self.algorithm.value
        if self.algorithm is None and isinstance(self.enabled, bool):
            # Only the True/False singletons qualify here.
            return self.enabled

        payload: Dict[str, Any] = {}
        if self.enabled is not None:
            payload["enabled"] = self.enabled
        if self.algorithm is not None:
            payload["algorithm"] = self.algorithm.value
        return payload
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@dataclass
class MessageExtras:
    """Optional extra attributes attached to a message; all default to ``None``."""

    # Free-form header mapping.
    headers: Optional[Dict[str, Any]] = None
    # Flags / keys whose exact server-side meaning is not visible here.
    ephemeral: Optional[bool] = None
    idempotency_key: Optional[str] = None
    echo: Optional[bool] = None
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
@dataclass
class SubscriptionOptions:
    """Optional per-subscription settings; every field defaults to ``None``."""

    # Server-side filter tree (see FilterNode / Filter).
    filter: Optional[FilterNode] = None
    # Per-channel delta-compression preferences.
    delta: Optional[ChannelDeltaSettings] = None
    # Event names to receive.
    events: Optional[List[str]] = None
    # History replay request.
    rewind: Optional["SubscriptionRewind"] = None
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
@dataclass
class SubscriptionRewind:
    """History-replay request, expressed as a message count or a time window."""

    count: Optional[int] = None
    seconds: Optional[int] = None

    @classmethod
    def count_messages(cls, count: int) -> "SubscriptionRewind":
        """Build a rewind that asks for ``count`` messages."""
        return cls(count=count)

    @classmethod
    def seconds_back(cls, seconds: int) -> "SubscriptionRewind":
        """Build a rewind that asks for the last ``seconds`` seconds."""
        return cls(seconds=seconds)

    def subscription_value(self) -> Any:
        """Return the wire value: the bare count, else ``{"seconds": n}``.

        ``count`` wins when both are set.

        Raises:
            SockudoException: if neither ``count`` nor ``seconds`` is set.
        """
        if self.count is None and self.seconds is None:
            raise SockudoException("SubscriptionRewind requires count or seconds")
        if self.count is not None:
            return self.count
        return {"seconds": self.seconds}
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
@dataclass
class DeltaStats:
    """Counters describing delta-compression traffic."""

    # Message counters.
    total_messages: int = 0
    delta_messages: int = 0
    full_messages: int = 0
    # Byte counters with and without compression.
    total_bytes_without_compression: int = 0
    total_bytes_with_compression: int = 0
    errors: int = 0

    @property
    def bandwidth_saved(self) -> int:
        """Uncompressed byte total minus compressed byte total."""
        return self.total_bytes_without_compression - self.total_bytes_with_compression

    @property
    def bandwidth_saved_percent(self) -> float:
        """``bandwidth_saved`` as a percentage of the uncompressed total.

        Returns ``0.0`` when nothing has been counted yet (avoids dividing by 0).
        """
        baseline = self.total_bytes_without_compression
        if baseline == 0:
            return 0.0
        return self.bandwidth_saved / baseline * 100.0
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def _default_delta_algorithms() -> List[DeltaAlgorithm]:
    """Return a fresh default algorithm list: fossil, then xdelta3."""
    return [DeltaAlgorithm.FOSSIL, DeltaAlgorithm.XDELTA3]


@dataclass
class DeltaOptions:
    """Client-wide delta-compression configuration."""

    enabled: Optional[bool] = None
    # A new list per instance, so callers may mutate it safely.
    algorithms: List[DeltaAlgorithm] = field(default_factory=_default_delta_algorithms)
    debug: bool = False
    # Optional observer callbacks.
    on_stats: Optional[Callable[[DeltaStats], None]] = None
    on_error: Optional[Callable[[BaseException], None]] = None
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
@dataclass
class PresenceMember:
    """A member of a presence channel: its id plus arbitrary info payload."""

    id: str
    info: Any  # shape is application-defined
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
# Scalar types allowed as auth / timeline parameter values.
AuthValue = Union[str, int, float, bool]

# Zero-argument callables that supply extra request headers / parameters.
HeadersProvider = Callable[[], Dict[str, str]]
ParamsProvider = Callable[[], Dict[str, AuthValue]]

# Async custom auth hooks. The request/data classes are named as strings
# because they are declared further down in the module.
ChannelAuthHandler = Callable[
    ["ChannelAuthorizationRequest"], Awaitable["ChannelAuthorizationData"]
]
UserAuthHandler = Callable[
    ["UserAuthenticationRequest"], Awaitable["UserAuthenticationData"]
]

# Header provider for presence-history requests; same shape as HeadersProvider.
PresenceHistoryHeadersProvider = HeadersProvider
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
@dataclass
class ChannelAuthorizationData:
    """Result of a channel authorization: the signature plus optional extras."""

    auth: str
    # Optional payloads; presumably channel_data is used for presence channels
    # and shared_secret for encrypted channels — confirm against the server.
    channel_data: Optional[str] = None
    shared_secret: Optional[str] = None
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
@dataclass
class UserAuthenticationData:
    """Result of user authentication: signature and serialised user data."""

    auth: str
    user_data: str
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
@dataclass
class ChannelAuthorizationRequest:
    """Inputs handed to a channel authorization handler."""

    socket_id: str
    channel_name: str
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
@dataclass
class UserAuthenticationRequest:
    """Inputs handed to a user authentication handler."""

    socket_id: str
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
@dataclass
class ChannelAuthorizationOptions:
    """How channel authorization requests are made.

    Static ``headers``/``params`` can be complemented by provider callables;
    ``custom_handler`` is an async hook taking a ChannelAuthorizationRequest.
    """

    endpoint: str = "/sockudo/auth"
    # Static request extras (fresh dict per instance).
    headers: Dict[str, str] = field(default_factory=dict)
    params: Dict[str, AuthValue] = field(default_factory=dict)
    # Dynamic request extras.
    headers_provider: Optional[HeadersProvider] = None
    params_provider: Optional[ParamsProvider] = None
    custom_handler: Optional[ChannelAuthHandler] = None
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
@dataclass
class UserAuthenticationOptions:
    """How user authentication requests are made.

    Mirrors ChannelAuthorizationOptions, with a different default endpoint and
    a ``custom_handler`` that takes a UserAuthenticationRequest.
    """

    endpoint: str = "/sockudo/user-auth"
    # Static request extras (fresh dict per instance).
    headers: Dict[str, str] = field(default_factory=dict)
    params: Dict[str, AuthValue] = field(default_factory=dict)
    # Dynamic request extras.
    headers_provider: Optional[HeadersProvider] = None
    params_provider: Optional[ParamsProvider] = None
    custom_handler: Optional[UserAuthHandler] = None
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
@dataclass
class PresenceHistoryOptions:
    """Where and how presence-history requests are sent."""

    endpoint: str  # required
    headers: Dict[str, str] = field(default_factory=dict)
    headers_provider: Optional[PresenceHistoryHeadersProvider] = None
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
@dataclass
class SockudoOptions:
    """Top-level client configuration. Only ``cluster`` is required.

    Field order is part of the positional constructor, so it must stay as is.
    Timeout values are floats, presumably seconds (confirm with the client).
    """

    cluster: str
    # --- protocol / connection behaviour ---
    protocol_version: int = 2
    activity_timeout: float = 120.0
    force_tls: Optional[bool] = None
    enabled_transports: Optional[List[SockudoTransport]] = None
    disabled_transports: Optional[List[SockudoTransport]] = None
    # --- WebSocket endpoint ---
    ws_host: Optional[str] = None
    ws_port: int = 80
    wss_port: int = 443
    ws_path: str = ""
    # --- HTTP endpoint ---
    http_host: Optional[str] = None
    http_port: int = 80
    https_port: int = 443
    http_path: str = "/sockudo"
    # --- liveness ---
    pong_timeout: float = 30.0
    unavailable_timeout: float = 10.0
    # --- stats ---
    enable_stats: bool = False
    stats_host: str = "stats.sockudo.io"
    timeline_params: Dict[str, AuthValue] = field(default_factory=dict)
    # --- auth (a fresh options object per client) ---
    channel_authorization: ChannelAuthorizationOptions = field(
        default_factory=ChannelAuthorizationOptions
    )
    user_authentication: UserAuthenticationOptions = field(
        default_factory=UserAuthenticationOptions
    )
    # --- optional features ---
    presence_history: Optional[PresenceHistoryOptions] = None
    delta_compression: Optional[DeltaOptions] = None
    message_deduplication: bool = True
    message_deduplication_capacity: int = 1000
    connection_recovery: bool = False
    echo_messages: bool = True
    wire_format: SockudoWireFormat = SockudoWireFormat.JSON
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
@dataclass
class EventMetadata:
    """Extra context passed to per-event callbacks alongside the data."""

    user_id: Optional[str] = None  # sender's user id, when known
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
@dataclass
class PresenceHistoryParams:
    """Query parameters for a presence-history request.

    ``start``/``end`` are aliases for ``start_time_ms``/``end_time_ms``; the
    explicit ``*_time_ms`` fields take precedence when both are given.
    """

    direction: Optional[str] = None
    limit: Optional[int] = None
    cursor: Optional[str] = None
    start_serial: Optional[int] = None
    end_serial: Optional[int] = None
    start_time_ms: Optional[int] = None
    end_time_ms: Optional[int] = None
    start: Optional[int] = None
    end: Optional[int] = None

    def to_payload(self) -> Dict[str, Any]:
        """Return the non-``None`` parameters as a dict, resolving aliases."""
        start_ms = self.start if self.start_time_ms is None else self.start_time_ms
        end_ms = self.end if self.end_time_ms is None else self.end_time_ms
        entries = (
            ("direction", self.direction),
            ("limit", self.limit),
            ("cursor", self.cursor),
            ("start_serial", self.start_serial),
            ("end_serial", self.end_serial),
            ("start_time_ms", start_ms),
            ("end_time_ms", end_ms),
        )
        return {name: value for name, value in entries if value is not None}
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
@dataclass
class PresenceHistoryBounds:
    """Serial and time range covered by a presence-history page (any may be None)."""

    # Serial range.
    start_serial: Optional[int]
    end_serial: Optional[int]
    # Time range in milliseconds.
    start_time_ms: Optional[int]
    end_time_ms: Optional[int]
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
@dataclass
class PresenceHistoryContinuity:
    """Retention/continuity information reported with presence history."""

    stream_id: Optional[str]
    # Oldest/newest retained entries, by serial and by publish time (ms).
    oldest_available_serial: Optional[int]
    newest_available_serial: Optional[int]
    oldest_available_published_at_ms: Optional[int]
    newest_available_published_at_ms: Optional[int]
    # Retention volume.
    retained_events: int
    retained_bytes: int
    # Status flags as reported by the server.
    degraded: bool
    complete: bool
    truncated_by_retention: bool
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
@dataclass
class PresenceHistoryItem:
    """One recorded presence event in a history page."""

    # Position in the history stream.
    stream_id: str
    serial: int
    published_at_ms: int
    # What happened and to whom.
    event: str
    cause: str
    user_id: str
    connection_id: Optional[str]
    dead_node_id: Optional[str]
    # Payload details.
    payload_size_bytes: int
    presence_event: Dict[str, Any]
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
@dataclass
class PresenceSnapshotParams:
    """Point-in-time selector for a presence snapshot.

    ``at`` is an alias for ``at_time_ms``; ``at_time_ms`` wins when both are set.
    """

    at_time_ms: Optional[int] = None
    at: Optional[int] = None
    at_serial: Optional[int] = None

    def to_payload(self) -> Dict[str, Any]:
        """Return the non-``None`` selectors as a dict, resolving the alias."""
        moment = self.at if self.at_time_ms is None else self.at_time_ms
        entries = (("at_time_ms", moment), ("at_serial", self.at_serial))
        return {name: value for name, value in entries if value is not None}
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
@dataclass
class PresenceSnapshotMember:
    """A member present in a snapshot, with the last event that affected it."""

    user_id: str
    last_event: str
    last_event_serial: int
    last_event_at_ms: int
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
@dataclass
class PresenceSnapshot:
    """Reconstructed presence membership of a channel at a point in history."""

    channel: str
    members: List[PresenceSnapshotMember]
    member_count: int
    events_replayed: int
    # Position the snapshot was taken at (may be None).
    snapshot_serial: Optional[int]
    snapshot_time_ms: Optional[int]
    continuity: PresenceHistoryContinuity
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
@dataclass
class PresenceHistoryPage:
    """One page of presence-history results with cursor-based pagination.

    ``_fetch_next`` is the async loader used by :meth:`next`. It is excluded
    from ``repr`` and equality, so two pages holding the same data compare
    equal no matter which callback produced them.
    """

    items: List[PresenceHistoryItem]
    direction: str
    limit: int
    has_more: bool
    next_cursor: Optional[str]
    bounds: PresenceHistoryBounds
    continuity: PresenceHistoryContinuity
    _fetch_next: Optional[Callable[[str], Awaitable["PresenceHistoryPage"]]] = field(
        default=None, repr=False, compare=False
    )

    def has_next(self) -> bool:
        """Return True when the server reports more data and gave a cursor."""
        return self.has_more and self.next_cursor is not None

    async def next(self) -> "PresenceHistoryPage":
        """Fetch the following page via ``_fetch_next(next_cursor)``.

        Raises:
            SockudoException: if there is no next page or no loader is attached.
        """
        if not self.has_next() or self._fetch_next is None:
            raise SockudoException("No more pages available")
        return await self._fetch_next(self.next_cursor)
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
@dataclass
class SockudoEvent:
    """A decoded inbound event (built by ``ProtocolCodec.decode_event``).

    ``raw_message`` is the original text for JSON frames, or a compact JSON
    re-encoding of the decoded envelope for binary wire formats.
    """

    event: str
    channel: Optional[str]
    data: Any  # JSON-decoded when the payload was a JSON string
    user_id: Optional[str]
    message_id: Optional[str]
    stream_id: Optional[str]
    raw_message: str
    # Optional envelope fields.
    sequence: Optional[int] = None
    conflation_key: Optional[str] = None
    serial: Optional[int] = None
    extras: Optional[MessageExtras] = None
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
@dataclass
class RecoveryPosition:
    """Last known stream position, keyed by serial, for connection recovery."""

    serial: int
    stream_id: Optional[str] = None
    last_message_id: Optional[str] = None
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
class ProtocolPrefix:
    """Event-name prefixes for a protocol version.

    Versions ``>= 2`` use the ``sockudo:`` / ``sockudo_internal:`` prefixes and
    report version ``"2"``. Anything lower falls back to the Pusher-style
    ``pusher:`` / ``pusher_internal:`` prefixes and reports ``"7"``.
    """

    def __init__(self, version: int) -> None:
        if version >= 2:
            self.version = "2"
            self.event_prefix = "sockudo:"
            self.internal_prefix = "sockudo_internal:"
        else:
            self.version = "7"
            self.event_prefix = "pusher:"
            self.internal_prefix = "pusher_internal:"

    def event(self, name: str) -> str:
        """Qualify ``name`` with the platform event prefix."""
        return self.event_prefix + name

    def internal(self, name: str) -> str:
        """Qualify ``name`` with the internal event prefix."""
        return self.internal_prefix + name

    def is_internal_event(self, name: str) -> bool:
        """True when ``name`` carries the internal prefix."""
        return name.startswith(self.internal_prefix)

    def is_platform_event(self, name: str) -> bool:
        """True when ``name`` carries the platform event prefix."""
        return name.startswith(self.event_prefix)
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
class EventDispatcher:
    """Registry of per-event and global callbacks with token-based unbinding.

    Args:
        failthrough: Optional ``(event_name, data)`` callable that ``emit``
            invokes when no per-event callback is bound for the event.
    """

    def __init__(
        self, failthrough: Optional[Callable[[str, Any], None]] = None
    ) -> None:
        # event name -> {token: callback(data, metadata)}
        self._callbacks: Dict[
            str, "OrderedDict[str, Callable[[Any, Optional[EventMetadata]], None]]"
        ] = {}
        # token -> callback(event_name, data), invoked for every event
        self._global_callbacks: "OrderedDict[str, Callable[[str, Any], None]]" = (
            OrderedDict()
        )
        self._failthrough = failthrough

    @staticmethod
    def _new_token() -> str:
        # 9 random bytes -> 12 URL-safe base64 characters (no padding).
        return base64.urlsafe_b64encode(nacl_random(9)).decode("ascii")

    def bind(
        self, event_name: str, callback: Callable[[Any, Optional[EventMetadata]], None]
    ) -> str:
        """Register ``callback`` for ``event_name``; return its unbind token."""
        token = self._new_token()
        self._callbacks.setdefault(event_name, OrderedDict())[token] = callback
        return token

    def bind_global(self, callback: Callable[[str, Any], None]) -> str:
        """Register ``callback`` for every event; return its unbind token."""
        token = self._new_token()
        self._global_callbacks[token] = callback
        return token

    def unbind_global(self, token: Optional[str] = None) -> None:
        """Remove one global callback by token, or all of them when ``None``."""
        if token is None:
            self._global_callbacks.clear()
            return
        self._global_callbacks.pop(token, None)

    def unbind(
        self, event_name: Optional[str] = None, token: Optional[str] = None
    ) -> None:
        """Remove callbacks.

        * ``event_name`` only: every callback for that event.
        * ``event_name`` and ``token``: that one callback.
        * ``token`` only: that token from every event and the global list.
        * neither: everything, global callbacks included.
        """
        if event_name is not None and token is None:
            self._callbacks.pop(event_name, None)
            return
        if event_name is not None and token is not None:
            callbacks = self._callbacks.get(event_name)
            if callbacks is None:
                return
            callbacks.pop(token, None)
            if not callbacks:
                self._callbacks.pop(event_name, None)
            return
        if token is not None:
            for name in list(self._callbacks):
                self._callbacks[name].pop(token, None)
                if not self._callbacks[name]:
                    self._callbacks.pop(name, None)
            self._global_callbacks.pop(token, None)
            return
        self._callbacks.clear()
        self._global_callbacks.clear()

    def emit(
        self, event_name: str, data: Any, metadata: Optional[EventMetadata] = None
    ) -> None:
        """Deliver an event: global callbacks first, then per-event callbacks.

        Falls back to ``failthrough`` when no per-event callback is bound.
        Callbacks are snapshotted before being invoked, so a callback may bind
        or unbind (e.g. a one-shot handler removing itself) without raising
        "OrderedDict mutated during iteration". A callback removed during this
        emit by an earlier callback still runs for this event.
        """
        for callback in list(self._global_callbacks.values()):
            callback(event_name, data)
        callbacks = self._callbacks.get(event_name)
        if not callbacks:
            if self._failthrough is not None:
                self._failthrough(event_name, data)
            return
        for callback in list(callbacks.values()):
            callback(data, metadata)
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
class MessageDeduplicator:
    """Bounded record of recently seen message ids; oldest entries are evicted first.

    Args:
        capacity: Maximum number of ids remembered. Negative values are
            treated as ``0`` (remember nothing). Without the clamp, ``track``
            would try to evict from an empty dict and raise ``KeyError``.
    """

    def __init__(self, capacity: int = 1000) -> None:
        self._capacity = max(0, capacity)
        self._seen: "OrderedDict[str, bool]" = OrderedDict()

    def is_duplicate(self, message_id: str) -> bool:
        """Return True if ``message_id`` is currently remembered."""
        return message_id in self._seen

    def track(self, message_id: str) -> None:
        """Remember ``message_id`` as the newest entry, evicting beyond capacity."""
        if message_id in self._seen:
            # Re-tracking refreshes recency.
            self._seen.move_to_end(message_id)
        else:
            self._seen[message_id] = True
        while len(self._seen) > self._capacity:
            self._seen.popitem(last=False)
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
class FossilDelta:
    """Decoder for Fossil-format deltas.

    A delta is ``<output size>\\n`` followed by commands:

    * ``<count>@<offset>,`` copies ``count`` bytes of ``base`` starting at ``offset``;
    * ``<count>:<bytes>`` inserts the next ``count`` literal bytes of the delta;
    * ``<checksum>;`` ends the delta and carries a 32-bit checksum of the output.

    Integers use the base-64 alphabet in ``_digits``.
    """

    _digits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz~"
    _values = {ord(ch): index for index, ch in enumerate(_digits)}

    class _Reader:
        """Cursor over the delta bytes."""

        def __init__(self, data: bytes) -> None:
            self.data = data
            self.position = 0

        @property
        def has_bytes(self) -> bool:
            return self.position < len(self.data)

        def byte(self) -> int:
            if self.position >= len(self.data):
                raise DeltaFailure("out of bounds")
            value = self.data[self.position]
            self.position += 1
            return value

        def character(self) -> str:
            return chr(self.byte())

        def integer(self) -> int:
            """Read a base-64 integer, stopping before the first non-digit byte."""
            value = 0
            while self.has_bytes:
                raw = self.byte()
                mapped = FossilDelta._values.get(raw, -1)
                if mapped < 0:
                    # Leave the terminator for the caller to inspect.
                    self.position -= 1
                    break
                value = (value << 6) + mapped
            return value

    @staticmethod
    def apply(base: bytes, delta: bytes) -> bytes:
        """Apply ``delta`` to ``base`` and return the reconstructed bytes.

        Raises:
            DeltaFailure: on malformed input, out-of-range copies/inserts, a
                size mismatch, or a checksum mismatch.
        """
        reader = FossilDelta._Reader(delta)
        output_size = reader.integer()
        if reader.character() != "\n":
            raise DeltaFailure("size integer not terminated by newline")
        output = bytearray()
        total = 0
        while reader.has_bytes:
            count = reader.integer()
            op = reader.character()
            if op == "@":
                offset = reader.integer()
                if reader.has_bytes and reader.character() != ",":
                    raise DeltaFailure("copy command not terminated by comma")
                total += count
                if total > output_size:
                    raise DeltaFailure("copy exceeds output file size")
                if offset + count > len(base):
                    raise DeltaFailure("copy extends past end of input")
                output.extend(base[offset : offset + count])
            elif op == ":":
                total += count
                if total > output_size:
                    raise DeltaFailure(
                        "insert command gives an output larger than predicted"
                    )
                if reader.position + count > len(delta):
                    raise DeltaFailure("insert count exceeds size of delta")
                output.extend(delta[reader.position : reader.position + count])
                reader.position += count
            elif op == ";":
                payload = bytes(output)
                if count != FossilDelta._checksum(payload):
                    raise DeltaFailure("bad checksum")
                if total != output_size:
                    raise DeltaFailure("generated size does not match predicted size")
                return payload
            else:
                raise DeltaFailure("unknown delta operator")
        raise DeltaFailure("unterminated delta")

    @staticmethod
    def _checksum(data: bytes) -> int:
        """Return Fossil's unsigned 32-bit checksum of ``data``.

        Bytes of the longest prefix whose length is a multiple of 4 are summed
        in four interleaved lanes. The lanes are combined as
        ``s3 + (s2 << 8) + (s1 << 16) + (s0 << 24)``, and the 1-3 trailing
        bytes are added at shifts 24, 16 and 8. The reference implementations
        do this in unsigned 32-bit arithmetic, so the result is reduced modulo
        2**32. Python ints never wrap, and without the mask any non-trivial
        payload would fail checksum verification.
        """
        whole = len(data) - len(data) % 4
        sum0 = sum(data[0:whole:4])
        sum1 = sum(data[1:whole:4])
        sum2 = sum(data[2:whole:4])
        sum3 = sum(data[3:whole:4])
        total = sum3 + (sum2 << 8) + (sum1 << 16) + (sum0 << 24)
        for shift, byte in zip((24, 16, 8), data[whole:]):
            total += byte << shift
        return total & 0xFFFFFFFF
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
class ProtocolCodec:
|
|
797
|
+
_messagepack_fields = [
|
|
798
|
+
"event",
|
|
799
|
+
"channel",
|
|
800
|
+
"data",
|
|
801
|
+
"name",
|
|
802
|
+
"user_id",
|
|
803
|
+
"tags",
|
|
804
|
+
"sequence",
|
|
805
|
+
"conflation_key",
|
|
806
|
+
"message_id",
|
|
807
|
+
"serial",
|
|
808
|
+
"idempotency_key",
|
|
809
|
+
"extras",
|
|
810
|
+
"__delta_seq",
|
|
811
|
+
"__conflation_key",
|
|
812
|
+
"stream_id",
|
|
813
|
+
]
|
|
814
|
+
|
|
815
|
+
@staticmethod
|
|
816
|
+
def encode_envelope(
|
|
817
|
+
envelope: Dict[str, Any], wire_format: SockudoWireFormat
|
|
818
|
+
) -> Union[str, bytes]:
|
|
819
|
+
if wire_format is SockudoWireFormat.JSON:
|
|
820
|
+
return json.dumps(envelope, separators=(",", ":"))
|
|
821
|
+
if wire_format is SockudoWireFormat.MESSAGEPACK:
|
|
822
|
+
payload = [
|
|
823
|
+
envelope.get("event"),
|
|
824
|
+
envelope.get("channel"),
|
|
825
|
+
ProtocolCodec._encode_messagepack_data(envelope.get("data")),
|
|
826
|
+
envelope.get("name"),
|
|
827
|
+
envelope.get("user_id"),
|
|
828
|
+
envelope.get("tags"),
|
|
829
|
+
envelope.get("sequence"),
|
|
830
|
+
envelope.get("conflation_key"),
|
|
831
|
+
envelope.get("message_id"),
|
|
832
|
+
envelope.get("serial"),
|
|
833
|
+
envelope.get("idempotency_key"),
|
|
834
|
+
ProtocolCodec._encode_messagepack_extras(envelope.get("extras")),
|
|
835
|
+
envelope.get("__delta_seq"),
|
|
836
|
+
envelope.get("__conflation_key"),
|
|
837
|
+
envelope.get("stream_id"),
|
|
838
|
+
]
|
|
839
|
+
return msgpack.packb(payload, use_bin_type=True)
|
|
840
|
+
return ProtocolCodec._encode_protobuf(envelope)
|
|
841
|
+
|
|
842
|
+
@staticmethod
|
|
843
|
+
def decode_event(
|
|
844
|
+
raw_message: Union[str, bytes], wire_format: SockudoWireFormat
|
|
845
|
+
) -> SockudoEvent:
|
|
846
|
+
envelope, raw_text = ProtocolCodec.decode_envelope(raw_message, wire_format)
|
|
847
|
+
raw_data = envelope.get("data")
|
|
848
|
+
data = raw_data
|
|
849
|
+
if isinstance(raw_data, str):
|
|
850
|
+
try:
|
|
851
|
+
data = json.loads(raw_data)
|
|
852
|
+
except json.JSONDecodeError:
|
|
853
|
+
data = raw_data
|
|
854
|
+
return SockudoEvent(
|
|
855
|
+
event=envelope["event"],
|
|
856
|
+
channel=envelope.get("channel"),
|
|
857
|
+
data=data,
|
|
858
|
+
user_id=envelope.get("user_id"),
|
|
859
|
+
message_id=envelope.get("message_id"),
|
|
860
|
+
stream_id=envelope.get("stream_id"),
|
|
861
|
+
raw_message=raw_text,
|
|
862
|
+
sequence=_coerce_int(envelope.get("__delta_seq", envelope.get("sequence"))),
|
|
863
|
+
conflation_key=envelope.get(
|
|
864
|
+
"__conflation_key", envelope.get("conflation_key")
|
|
865
|
+
),
|
|
866
|
+
serial=_coerce_int(envelope.get("serial")),
|
|
867
|
+
extras=ProtocolCodec._decode_extras(envelope.get("extras")),
|
|
868
|
+
)
|
|
869
|
+
|
|
870
|
+
@staticmethod
def decode_envelope(
    raw_message: Union[str, bytes], wire_format: SockudoWireFormat
) -> Tuple[Dict[str, Any], str]:
    """Decode a raw frame into ``(envelope, text)``.

    Args:
        raw_message: The frame as received. ``str``, ``bytes`` and
            ``bytearray`` are accepted for every wire format.
        wire_format: The codec negotiated for the connection.

    Returns:
        The envelope dict and its text form. For JSON this is the original
        frame text. For the binary formats (MessagePack and protobuf) it is a
        compact JSON dump of the decoded envelope.

    Raises:
        SockudoException: If a JSON frame is not an object, or a MessagePack
            frame is neither an array nor a map.
    """
    if wire_format is SockudoWireFormat.JSON:
        text = (
            raw_message.decode("utf-8")
            if isinstance(raw_message, (bytes, bytearray))
            else raw_message
        )
        decoded = json.loads(text)
        if not isinstance(decoded, dict):
            raise SockudoException("Unable to decode event envelope")
        return decoded, text

    # Binary codecs: normalise to bytes. Previously only ``bytes`` was
    # recognised here, so a ``bytearray`` frame fell through to ``.encode``
    # and raised AttributeError (the JSON branch above already accepted it).
    binary = (
        bytes(raw_message)
        if isinstance(raw_message, (bytes, bytearray))
        else raw_message.encode("utf-8")
    )

    if wire_format is SockudoWireFormat.MESSAGEPACK:
        unpacked = msgpack.unpackb(binary, raw=False)
        if isinstance(unpacked, list):
            # Positional form: slot i holds ProtocolCodec._messagepack_fields[i].
            # Missing trailing slots and None values are omitted.
            envelope: Dict[str, Any] = {}
            for field_name, item in zip(ProtocolCodec._messagepack_fields, unpacked):
                value = ProtocolCodec._decode_messagepack_value(item)
                if value is not None:
                    envelope[field_name] = value
        elif isinstance(unpacked, dict):
            envelope = {
                str(key): ProtocolCodec._decode_messagepack_value(value)
                for key, value in unpacked.items()
            }
        else:
            raise SockudoException("Unable to decode event envelope")
        return envelope, json.dumps(envelope, separators=(",", ":"))

    envelope = ProtocolCodec._decode_protobuf(binary)
    return envelope, json.dumps(envelope, separators=(",", ":"))
|
|
912
|
+
|
|
913
|
+
@staticmethod
def _encode_messagepack_data(value: Any) -> Any:
    """Wrap an event payload in the tagged pair used by the MessagePack codec.

    ``None`` passes through unchanged. A ``str`` becomes ``["string", value]``.
    Anything else is serialised to compact JSON and tagged ``"json"``.
    """
    if value is None:
        return None
    tag, body = (
        ("string", value)
        if isinstance(value, str)
        else ("json", json.dumps(value, separators=(",", ":")))
    )
    return [tag, body]
|
|
920
|
+
|
|
921
|
+
@staticmethod
def _encode_messagepack_extras(raw_extras: Any) -> Any:
    """Serialise message extras into the MessagePack wire shape.

    Header values are tagged as follows:

    - ``bool`` becomes ``["bool", v]``. This is checked first because
      ``bool`` subclasses ``int``.
    - Other ``int``/``float`` values become ``["number", float(v)]``.
    - Everything else becomes ``["string", str(v)]``.

    The optional ``ephemeral``, ``idempotency_key`` and ``echo`` fields are
    emitted only when they are not ``None``. Returns ``None`` when
    ``_decode_extras`` yields no extras.
    """
    extras = ProtocolCodec._decode_extras(raw_extras)
    if extras is None:
        return None

    def tag_header(value: Any) -> List[Any]:
        if isinstance(value, bool):
            return ["bool", value]
        if isinstance(value, (int, float)):
            return ["number", float(value)]
        return ["string", str(value)]

    result: Dict[str, Any] = {}
    if extras.headers is not None:
        result["headers"] = {
            key: tag_header(value) for key, value in extras.headers.items()
        }
    for attr_name in ("ephemeral", "idempotency_key", "echo"):
        attr_value = getattr(extras, attr_name)
        if attr_value is not None:
            result[attr_name] = attr_value
    return result
|
|
944
|
+
|
|
945
|
+
@staticmethod
def _decode_messagepack_value(value: Any) -> Any:
    """Recursively strip MessagePack type tags from a decoded value.

    A two-element list whose first item is one of ``"string"``, ``"json"``,
    ``"number"`` or ``"bool"`` collapses to its second item, which is not
    recursed into. Other lists and dicts are walked recursively, with dict
    keys coerced to ``str``. Any other value is returned unchanged.
    """
    if isinstance(value, dict):
        return {
            str(key): ProtocolCodec._decode_messagepack_value(item)
            for key, item in value.items()
        }
    if not isinstance(value, list):
        return value
    is_tagged_pair = (
        len(value) == 2
        and isinstance(value[0], str)
        and value[0] in {"string", "json", "number", "bool"}
    )
    if is_tagged_pair:
        return value[1]
    return [ProtocolCodec._decode_messagepack_value(item) for item in value]
|
|
959
|
+
|
|
960
|
+
@staticmethod
def _decode_extras(raw_extras: Any) -> Optional[MessageExtras]:
    """Normalise a wire ``extras`` value into a :class:`MessageExtras`.

    - ``None`` is returned as ``None``.
    - A ``MessageExtras`` instance is returned as-is.
    - Any other non-dict value yields ``None``.

    For dicts, header values that arrive as two-element lists (``[tag, value]``)
    are unwrapped to ``value``. A ``headers`` entry that is not a dict is
    dropped.
    """
    if raw_extras is None or isinstance(raw_extras, MessageExtras):
        return raw_extras
    if not isinstance(raw_extras, dict):
        return None

    raw_headers = raw_extras.get("headers")
    headers: Optional[Dict[str, Any]] = None
    if isinstance(raw_headers, dict):
        headers = {
            key: (value[1] if isinstance(value, list) and len(value) == 2 else value)
            for key, value in raw_headers.items()
        }

    return MessageExtras(
        headers=headers,
        ephemeral=raw_extras.get("ephemeral"),
        idempotency_key=raw_extras.get("idempotency_key"),
        echo=raw_extras.get("echo"),
    )
|
|
983
|
+
|
|
984
|
+
@staticmethod
def _encode_protobuf(envelope: Dict[str, Any]) -> bytes:
    """Serialise an envelope dict to the protobuf wire format.

    Fields are written in ascending field-number order:

    - 1 ``event``, 2 ``channel``, 3 ``data``, 5 ``user_id``
    - 7 ``sequence``, 8 ``conflation_key``, 9 ``message_id``, 10 ``serial``
    - 12 ``extras``, 13 ``__delta_seq``, 14 ``__conflation_key``, 15 ``stream_id``

    ``data`` is a nested message. A ``str`` payload goes in its field 1; any
    other non-``None`` value is written as compact JSON in its field 3.
    Handling of ``None`` for the other fields is left to the
    ``_write_*_field`` helpers.
    """
    buf = bytearray()
    _write_string_field(buf, 1, envelope.get("event"))
    _write_string_field(buf, 2, envelope.get("channel"))

    body = envelope.get("data")
    if body is not None:
        nested = bytearray()
        if isinstance(body, str):
            _write_string_field(nested, 1, body)
        else:
            _write_string_field(nested, 3, json.dumps(body, separators=(",", ":")))
        _write_bytes_field(buf, 3, bytes(nested))

    for write, number, key in (
        (_write_string_field, 5, "user_id"),
        (_write_uint_field, 7, "sequence"),
        (_write_string_field, 8, "conflation_key"),
        (_write_string_field, 9, "message_id"),
        (_write_uint_field, 10, "serial"),
    ):
        write(buf, number, envelope.get(key))

    extras = ProtocolCodec._encode_protobuf_extras(envelope.get("extras"))
    if extras is not None:
        _write_bytes_field(buf, 12, extras)

    for write, number, key in (
        (_write_uint_field, 13, "__delta_seq"),
        (_write_string_field, 14, "__conflation_key"),
        (_write_string_field, 15, "stream_id"),
    ):
        write(buf, number, envelope.get(key))
    return bytes(buf)
|
|
1009
|
+
|
|
1010
|
+
@staticmethod
def _encode_protobuf_extras(raw_extras: Any) -> Optional[bytes]:
    """Serialise message extras to the protobuf ``extras`` sub-message.

    Each header becomes a repeated field-1 entry holding key (field 1) and a
    value message (field 2). The value message uses:

    - field 3 for ``bool``;
    - field 2 (double) for other ``int``/``float``;
    - field 1 for ``str(value)`` of anything else.

    After the headers come field 2 ``ephemeral``, field 3
    ``idempotency_key`` and field 4 ``echo``. Returns ``None`` when
    ``_decode_extras`` yields no extras.
    """
    extras = ProtocolCodec._decode_extras(raw_extras)
    if extras is None:
        return None

    buf = bytearray()
    for key, value in (extras.headers or {}).items():
        value_msg = bytearray()
        if isinstance(value, bool):  # bool first: it is an int subclass
            _write_bool_field(value_msg, 3, value)
        elif isinstance(value, (int, float)):
            _write_double_field(value_msg, 2, float(value))
        else:
            _write_string_field(value_msg, 1, str(value))
        entry_msg = bytearray()
        _write_string_field(entry_msg, 1, key)
        _write_bytes_field(entry_msg, 2, bytes(value_msg))
        _write_bytes_field(buf, 1, bytes(entry_msg))

    _write_optional_bool_field(buf, 2, extras.ephemeral)
    _write_string_field(buf, 3, extras.idempotency_key)
    _write_optional_bool_field(buf, 4, extras.echo)
    return bytes(buf)
|
|
1033
|
+
|
|
1034
|
+
@staticmethod
def _decode_protobuf(payload: bytes) -> Dict[str, Any]:
    """Decode a protobuf-encoded envelope into a plain dict.

    This mirrors ``_encode_protobuf``:

    - string fields 1/2/5/8/9/14/15;
    - varint fields 7/10/13;
    - nested ``data`` (3) and ``extras`` (12) messages.

    A known field is decoded only when it arrives with the expected wire type
    (2 for strings/messages, 0 for varints). Anything else is skipped as
    unknown. This matches the wire-type checks in the nested
    ``_decode_proto_*`` helpers. Previously a mismatched wire type was parsed
    anyway, which desynchronised the read position for the rest of the buffer.
    """
    string_fields = {
        1: "event",
        2: "channel",
        5: "user_id",
        8: "conflation_key",
        9: "message_id",
        14: "__conflation_key",
        15: "stream_id",
    }
    varint_fields = {7: "sequence", 10: "serial", 13: "__delta_seq"}

    envelope: Dict[str, Any] = {}
    index = 0
    while index < len(payload):
        tag, index = _read_varint(payload, index)
        number, wire = tag >> 3, tag & 0x7
        if number in string_fields and wire == 2:
            raw, index = _read_length_delimited(payload, index)
            envelope[string_fields[number]] = raw.decode("utf-8")
        elif number in varint_fields and wire == 0:
            value, index = _read_varint(payload, index)
            envelope[varint_fields[number]] = value
        elif number == 3 and wire == 2:
            raw, index = _read_length_delimited(payload, index)
            envelope["data"] = ProtocolCodec._decode_proto_data(raw)
        elif number == 12 and wire == 2:
            raw, index = _read_length_delimited(payload, index)
            envelope["extras"] = ProtocolCodec._decode_proto_extras(raw)
        else:
            index = _skip_unknown(payload, index, wire)
    return envelope
|
|
1069
|
+
|
|
1070
|
+
@staticmethod
def _decode_proto_data(payload: bytes) -> Any:
    """Decode the nested ``data`` message of a protobuf envelope.

    Field 1 carries a plain string payload and field 3 a JSON string. Both
    are returned as ``str``; JSON is not parsed here. If both appear, field 1
    wins, and for repeated occurrences the last one wins. Returns ``None``
    when neither field is present.
    """
    found: Dict[int, Any] = {}
    cursor = 0
    end = len(payload)
    while cursor < end:
        tag, cursor = _read_varint(payload, cursor)
        number = tag >> 3
        wire = tag & 0x7
        if wire != 2 or number not in (1, 3):
            cursor = _skip_unknown(payload, cursor, wire)
            continue
        chunk, cursor = _read_length_delimited(payload, cursor)
        found[number] = chunk.decode("utf-8")
    for preferred in (1, 3):
        if preferred in found:
            return found[preferred]
    return None
|
|
1088
|
+
|
|
1089
|
+
@staticmethod
def _decode_proto_extras(payload: bytes) -> Dict[str, Any]:
    """Decode the protobuf ``extras`` sub-message into a plain dict.

    Recognised fields:

    - 1: header entry (repeated, length-delimited);
    - 2: ``ephemeral`` (varint, converted to bool);
    - 3: ``idempotency_key`` (string);
    - 4: ``echo`` (varint, converted to bool).

    Header entries without a key are dropped. ``"headers"`` is set only when
    at least one header was decoded. Other fields are skipped.
    """
    decoded: Dict[str, Any] = {}
    header_map: Dict[str, Any] = {}
    pos = 0
    while pos < len(payload):
        tag, pos = _read_varint(payload, pos)
        number, wire = tag >> 3, tag & 0x7
        if wire == 2 and number == 1:
            chunk, pos = _read_length_delimited(payload, pos)
            name, value = ProtocolCodec._decode_proto_header_entry(chunk)
            if name is not None:
                header_map[name] = value
        elif wire == 2 and number == 3:
            chunk, pos = _read_length_delimited(payload, pos)
            decoded["idempotency_key"] = chunk.decode("utf-8")
        elif wire == 0 and number in (2, 4):
            flag, pos = _read_varint(payload, pos)
            decoded["ephemeral" if number == 2 else "echo"] = bool(flag)
        else:
            pos = _skip_unknown(payload, pos, wire)
    if header_map:
        decoded["headers"] = header_map
    return decoded
|
|
1117
|
+
|
|
1118
|
+
@staticmethod
def _decode_proto_header_entry(payload: bytes) -> Tuple[Optional[str], Any]:
    """Decode one header map entry: field 1 is the key, field 2 the value message.

    Returns ``(key, value)``. Either element is ``None`` when its field was
    absent, and later occurrences overwrite earlier ones.
    """
    key: Optional[str] = None
    value: Any = None
    pos = 0
    while pos < len(payload):
        tag, pos = _read_varint(payload, pos)
        number, wire = tag >> 3, tag & 0x7
        if wire == 2 and number in (1, 2):
            chunk, pos = _read_length_delimited(payload, pos)
            if number == 1:
                key = chunk.decode("utf-8")
            else:
                value = ProtocolCodec._decode_proto_extra_value(chunk)
        else:
            pos = _skip_unknown(payload, pos, wire)
    return key, value
|
|
1136
|
+
|
|
1137
|
+
@staticmethod
def _decode_proto_extra_value(payload: bytes) -> Any:
    """Decode a header value message, returning its first recognised member.

    - Field 1 (length-delimited) becomes ``str``.
    - Field 2 (64-bit) becomes a little-endian double.
    - Field 3 (varint) becomes ``bool``.

    Unrecognised fields are skipped. Returns ``None`` if none of the three
    is present.
    """
    pos = 0
    while pos < len(payload):
        tag, pos = _read_varint(payload, pos)
        number, wire = tag >> 3, tag & 0x7
        if (number, wire) == (1, 2):
            chunk, _ = _read_length_delimited(payload, pos)
            return chunk.decode("utf-8")
        if (number, wire) == (2, 1):
            return struct.unpack("<d", payload[pos : pos + 8])[0]
        if (number, wire) == (3, 0):
            flag, _ = _read_varint(payload, pos)
            return bool(flag)
        pos = _skip_unknown(payload, pos, wire)
    return None
|
|
1154
|
+
|
|
1155
|
+
|
|
1156
|
+
class DeltaCompressionManager:
    """Reconstruct delta-compressed channel messages and track statistics.

    For every channel the manager keeps the last full message text as the
    base for the next delta. Statistics live in a :class:`DeltaStats`
    instance and are pushed to ``options.on_stats`` (when set) after every
    update. Byte counters count UTF-8 encoded bytes on both sides.
    """

    def __init__(
        self,
        options: DeltaOptions,
        send_event: Callable[[str, Any, Optional[str]], Awaitable[bool]],
        prefix: ProtocolPrefix,
    ) -> None:
        self._options = options
        # Coroutine used to send protocol events: (event_name, data, channel).
        self._send_event = send_event
        self._prefix = prefix
        self._enabled = False
        self._default_algorithm = DeltaAlgorithm.FOSSIL
        self._stats = DeltaStats()
        # Per-channel state dicts; "base_message" holds the current delta base text.
        self._channel_states: Dict[str, Dict[str, Any]] = {}

    async def enable(self) -> None:
        """Ask the server to enable delta compression (no-op once enabled).

        ``_enabled`` itself only changes when the server's reply is passed to
        :meth:`handle_enabled`.
        """
        if self._enabled:
            return
        await self._send_event(
            self._prefix.event("enable_delta_compression"),
            {"algorithms": [algorithm.value for algorithm in self._options.algorithms]},
            None,
        )

    def handle_enabled(self, data: Any) -> None:
        """Apply the server's enable acknowledgement.

        ``enabled`` defaults to ``True`` when absent. An unrecognised
        ``algorithm`` value leaves the current default algorithm untouched.
        """
        payload = data if isinstance(data, dict) else {}
        self._enabled = payload.get("enabled", True)
        if "algorithm" in payload:
            try:
                self._default_algorithm = DeltaAlgorithm(payload["algorithm"])
            except ValueError:
                pass

    def handle_cache_sync(self, channel: str, data: Any) -> None:
        """Replace a channel's state from a cache-sync message.

        The base message is reset to ``None``. A delta that arrives before the
        next full message therefore triggers a ``delta_sync_error`` request.
        """
        payload = data if isinstance(data, dict) else {}
        self._channel_states[channel] = {
            "conflation_key": payload.get("conflation_key"),
            "states": payload.get("states", {}),
            "base_message": None,
        }

    async def handle_delta_message(
        self, channel: str, data: Any
    ) -> Optional[SockudoEvent]:
        """Rebuild the full event from a delta against the channel's base message.

        ``data`` must carry ``event`` and base64 ``delta`` strings. It may also
        carry ``algorithm`` (defaults to the negotiated one), ``seq`` and
        ``conflation_key``.

        Returns:
            The reconstructed event, or ``None`` in these cases:

            - the payload is malformed;
            - the channel has no base message (a ``delta_sync_error`` is sent
              and the channel state is dropped);
            - reconstruction fails (``errors`` is incremented and
              ``options.on_error`` is called).
        """
        payload = data if isinstance(data, dict) else {}
        event_name = payload.get("event")
        delta_payload = payload.get("delta")
        if not isinstance(event_name, str) or not isinstance(delta_payload, str):
            return None
        algorithm = payload.get("algorithm", self._default_algorithm.value)
        sequence = _coerce_int(payload.get("seq"))
        conflation_key = payload.get("conflation_key")

        base_state = self._channel_states.get(channel)
        if base_state is None or base_state.get("base_message") is None:
            await self._send_event(
                self._prefix.event("delta_sync_error"), {"channel": channel}, None
            )
            self._channel_states.pop(channel, None)
            return None

        try:
            delta_bytes = base64.b64decode(delta_payload)
            base_bytes = base_state["base_message"].encode("utf-8")
            if algorithm == DeltaAlgorithm.XDELTA3.value:
                reconstructed_bytes = vcdiff_decoder.decode(base_bytes, delta_bytes)
            else:
                reconstructed_bytes = FossilDelta.apply(base_bytes, delta_bytes)
            reconstructed = reconstructed_bytes.decode("utf-8")
            parsed = json.loads(reconstructed)
            event_data = (
                parsed.get("data")
                if isinstance(parsed, dict) and "data" in parsed
                else parsed
            )
            # The reconstructed text is the base for the next delta. Store it
            # directly: routing through handle_full_message would also count
            # this delta as a *full* message, doubling total_messages and the
            # byte totals and firing on_stats twice per delta.
            self._set_base_message(channel, reconstructed)
            self._stats.delta_messages += 1
            self._stats.total_messages += 1
            self._stats.total_bytes_without_compression += len(reconstructed_bytes)
            self._stats.total_bytes_with_compression += len(delta_bytes)
            if self._options.on_stats:
                self._options.on_stats(self._stats)
            return SockudoEvent(
                event=event_name,
                channel=channel,
                data=event_data,
                user_id=None,
                message_id=None,
                raw_message=reconstructed,
                sequence=sequence,
                conflation_key=conflation_key,
            )
        except Exception as exc:  # not BaseException: let KeyboardInterrupt/SystemExit through
            self._stats.errors += 1
            if self._options.on_error:
                self._options.on_error(exc)
            return None

    def handle_full_message(
        self,
        channel: str,
        raw_message: str,
        sequence: Optional[int],
        conflation_key: Optional[str],
    ) -> None:
        """Record a full (non-delta) message as the channel's new delta base.

        ``sequence`` and ``conflation_key`` are accepted for API compatibility
        but are not currently used.
        """
        self._set_base_message(channel, raw_message)
        size = len(raw_message.encode("utf-8"))
        self._stats.full_messages += 1
        self._stats.total_messages += 1
        self._stats.total_bytes_without_compression += size
        self._stats.total_bytes_with_compression += size
        if self._options.on_stats:
            self._options.on_stats(self._stats)

    def _set_base_message(self, channel: str, raw_message: str) -> None:
        """Store ``raw_message`` as the base for the channel's next delta."""
        self._channel_states.setdefault(channel, {})["base_message"] = raw_message

    def get_stats(self) -> DeltaStats:
        """Return the live statistics object (not a copy)."""
        return self._stats

    def reset_stats(self) -> None:
        """Replace the statistics with a fresh :class:`DeltaStats`."""
        self._stats = DeltaStats()

    def clear_channel_state(self, channel: str) -> None:
        """Forget all delta state for ``channel`` (no-op if unknown)."""
        self._channel_states.pop(channel, None)
|
|
1278
|
+
|
|
1279
|
+
|
|
1280
|
+
class PresenceMembers:
    """Local view of the members of a presence channel.

    Members are stored as ``user_id -> user_info``. ``count`` comes from the
    server's subscription payload when provided and is then adjusted by
    :meth:`add` / :meth:`remove`. ``my_id`` and ``me`` identify the local user
    once known.
    """

    def __init__(self) -> None:
        self._members: Dict[str, Any] = {}
        self.count: int = 0
        self.my_id: Optional[str] = None
        self.me: Optional[PresenceMember] = None

    def member(self, member_id: str) -> Optional[PresenceMember]:
        """Return the member with ``member_id``, or ``None`` if unknown."""
        try:
            info = self._members[member_id]
        except KeyError:
            return None
        return PresenceMember(member_id, info)

    def remember_my_id(self, member_id: str) -> None:
        """Record the local user's id; ``me`` is resolved on subscription."""
        self.my_id = member_id

    def apply_subscription_data(self, data: Dict[str, Any]) -> None:
        """Load the member list from a ``subscription_succeeded`` payload.

        Reads ``data["presence"]["hash"]`` (members) and
        ``data["presence"]["count"]``, falling back to the hash size when the
        count is missing. A non-dict ``presence`` empties the list.
        """
        presence = data.get("presence", {})
        if isinstance(presence, dict):
            self._members = dict(presence.get("hash", {}))
            self.count = int(presence.get("count", len(self._members)))
        else:
            self._members = {}
            self.count = 0
        self.me = self.member(self.my_id) if self.my_id else None

    def add(self, data: Dict[str, Any]) -> Optional[PresenceMember]:
        """Add or update a member from a ``member_added`` payload.

        Returns the member, or ``None`` when ``user_id`` is not a string.
        ``count`` only grows for previously unknown ids.
        """
        user_id = data.get("user_id")
        if not isinstance(user_id, str):
            return None
        info = data.get("user_info")
        if user_id not in self._members:
            self.count += 1
        self._members[user_id] = info
        return PresenceMember(user_id, info)

    def remove(self, data: Dict[str, Any]) -> Optional[PresenceMember]:
        """Remove a member from a ``member_removed`` payload.

        Returns the removed member, or ``None`` if the id is not a known
        string. ``count`` never drops below zero.
        """
        user_id = data.get("user_id")
        if isinstance(user_id, str) and user_id in self._members:
            info = self._members.pop(user_id)
            self.count = max(0, self.count - 1)
            return PresenceMember(user_id, info)
        return None

    def reset(self) -> None:
        """Forget all members and the local user's identity."""
        self._members.clear()
        self.count = 0
        self.my_id = None
        self.me = None
|
|
1328
|
+
|
|
1329
|
+
|
|
1330
|
+
class SockudoChannel:
    """A public channel: subscription state plus a per-channel event dispatcher.

    Subclasses override :meth:`authorize` (private/presence/encrypted
    channels) and :meth:`handle` to add behaviour.
    """

    def __init__(self, name: str, client: "SockudoClient") -> None:
        self.name = name
        self.client = client
        self.dispatcher = EventDispatcher()
        self.is_subscribed = False
        self.subscription_pending = False
        self.subscription_cancelled = False
        self.subscription_count: Optional[int] = None
        self.filter: Optional[FilterNode] = None
        self.delta_settings: Optional[ChannelDeltaSettings] = None
        self.events_filter: Optional[List[str]] = None
        self.rewind: Optional[SubscriptionRewind] = None
        # Strong references to fire-and-forget tasks. The event loop only keeps
        # weak references, so an unreferenced task can be garbage-collected
        # before it finishes.
        self._background_tasks: "set[asyncio.Task[Any]]" = set()

    def bind(
        self, event_name: str, callback: Callable[[Any, Optional[EventMetadata]], None]
    ) -> str:
        """Register ``callback`` for ``event_name``; returns an unbind token."""
        return self.dispatcher.bind(event_name, callback)

    def bind_global(self, callback: Callable[[str, Any], None]) -> str:
        """Register ``callback`` for every event on this channel."""
        return self.dispatcher.bind_global(callback)

    def unbind(
        self, event_name: Optional[str] = None, token: Optional[str] = None
    ) -> None:
        """Remove callbacks by event name and/or token."""
        self.dispatcher.unbind(event_name, token)

    async def trigger(self, event: str, data: Any) -> bool:
        """Send a client event on this channel.

        Raises:
            BadEventName: if ``event`` does not start with ``client-``.
        """
        if not event.startswith("client-"):
            raise BadEventName(f"Event '{event}' does not start with 'client-'")
        return await self.client.send_event(event, data, self.name)

    async def authorize(self, socket_id: str) -> ChannelAuthorizationData:
        """Public channels need no authorisation; returns an empty auth string."""
        return ChannelAuthorizationData(auth="")

    def subscribe_if_possible(self) -> None:
        """Start a subscription when connected and none is already in flight.

        If a subscription is pending but was cancelled, the cancellation is
        withdrawn instead of sending a second request.
        """
        if self.subscription_pending and self.subscription_cancelled:
            self.subscription_cancelled = False
        elif (
            not self.subscription_pending
            and self.client.connection_state is ConnectionState.CONNECTED
        ):
            self._spawn(self.subscribe())

    def _spawn(self, coro: Awaitable[Any]) -> None:
        """Schedule ``coro`` as a task and keep a reference until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def subscribe(self) -> None:
        """Authorise and send the ``subscribe`` event, including optional filters.

        Authorisation or send failures are reported by emitting a
        ``subscription_error`` event with ``{"type": "AuthError", ...}``.
        Task cancellation clears the pending flag and propagates instead of
        being reported as an auth error.
        """
        if self.is_subscribed:
            return
        self.subscription_pending = True
        self.subscription_cancelled = False
        try:
            auth = await self.authorize(self.client.socket_id or "")
            payload: Dict[str, Any] = {"auth": auth.auth, "channel": self.name}
            if auth.channel_data is not None:
                payload["channel_data"] = auth.channel_data
            if self.filter is not None:
                payload["tags_filter"] = self.filter.to_dict()
            if self.delta_settings is not None:
                payload["delta"] = self.delta_settings.subscription_value()
            if self.events_filter is not None:
                payload["events"] = self.events_filter
            if self.rewind is not None:
                payload["rewind"] = self.rewind.subscription_value()
            await self.client.send_event(
                self.client.prefix.event("subscribe"), payload, None
            )
        except asyncio.CancelledError:
            self.subscription_pending = False
            raise
        except Exception as exc:
            self.subscription_pending = False
            self.dispatcher.emit(
                self.client.prefix.event("subscription_error"),
                {"type": "AuthError", "error": str(exc)},
            )

    async def unsubscribe(self) -> None:
        """Mark the channel unsubscribed and send the ``unsubscribe`` event."""
        self.is_subscribed = False
        await self.client.send_event(
            self.client.prefix.event("unsubscribe"), {"channel": self.name}, None
        )

    def disconnect(self) -> None:
        """Reset subscription flags after the connection drops."""
        self.is_subscribed = False
        self.subscription_pending = False

    def handle(self, event: SockudoEvent) -> None:
        """Dispatch an inbound event addressed to this channel.

        Internal ``subscription_succeeded`` completes the subscription. If the
        subscription was cancelled meanwhile, an unsubscribe is scheduled
        instead of emitting. Internal ``subscription_count`` updates
        :attr:`subscription_count`. Other internal events are ignored; all
        remaining events are emitted to bound callbacks.
        """
        p = self.client.prefix
        if event.event == p.internal("subscription_succeeded"):
            self.subscription_pending = False
            self.is_subscribed = True
            if self.subscription_cancelled:
                self._spawn(self.client.unsubscribe(self.name))
            else:
                self.dispatcher.emit(p.event("subscription_succeeded"), event.data)
        elif event.event == p.internal("subscription_count"):
            if isinstance(event.data, dict):
                self.subscription_count = _coerce_int(
                    event.data.get("subscription_count")
                )
            self.dispatcher.emit(p.event("subscription_count"), event.data)
        elif not p.is_internal_event(event.event):
            self.dispatcher.emit(
                event.event, event.data, EventMetadata(user_id=event.user_id)
            )
|
|
1431
|
+
|
|
1432
|
+
|
|
1433
|
+
class PrivateChannel(SockudoChannel):
    """Channel that must be authorised before subscribing.

    Authorisation is delegated to ``client.config.authorize_channel``.
    """

    async def authorize(self, socket_id: str) -> ChannelAuthorizationData:
        """Request channel authorisation for ``socket_id`` on this channel."""
        authorize_channel = self.client.config.authorize_channel
        request = ChannelAuthorizationRequest(socket_id, self.name)
        return await authorize_channel(request)
|
|
1438
|
+
|
|
1439
|
+
|
|
1440
|
+
class PresenceChannel(PrivateChannel):
    """Private channel that also tracks its members via :class:`PresenceMembers`."""

    def __init__(self, name: str, client: "SockudoClient") -> None:
        super().__init__(name, client)
        self.members = PresenceMembers()

    async def authorize(self, socket_id: str) -> ChannelAuthorizationData:
        """Authorise the subscription and remember the local user's id.

        When the auth response carries ``channel_data``, its ``user_id`` is
        used (if it is a string). Otherwise ``client.user.user_id`` is used.

        Raises:
            AuthFailure: if there is no ``channel_data`` and no signed-in user id.
            json.JSONDecodeError: if ``channel_data`` is not valid JSON.
        """
        response = await super().authorize(socket_id)
        if response.channel_data:
            parsed = json.loads(response.channel_data)
            if isinstance(parsed, dict) and isinstance(parsed.get("user_id"), str):
                self.members.remember_my_id(parsed["user_id"])
            return response
        if self.client.user.user_id:
            self.members.remember_my_id(self.client.user.user_id)
            return response
        raise AuthFailure(
            None, f"Invalid auth response for presence channel '{self.name}'"
        )

    def handle(self, event: SockudoEvent) -> None:
        """Handle presence lifecycle events, delegating the rest to the base class.

        ``subscription_succeeded`` loads the member list and emits
        ``self.members``. ``member_added`` / ``member_removed`` update it and
        emit the affected member.
        """
        p = self.client.prefix
        if event.event == p.internal("subscription_succeeded"):
            if self.subscription_cancelled:
                # The caller unsubscribed while the request was in flight. The
                # base class marks the channel subscribed and schedules the
                # unsubscribe; previously this branch ignored the cancellation.
                super().handle(event)
                return
            self.subscription_pending = False
            self.is_subscribed = True
            payload = event.data if isinstance(event.data, dict) else {}
            self.members.apply_subscription_data(payload)
            self.dispatcher.emit(p.event("subscription_succeeded"), self.members)
        elif event.event == p.internal("member_added") and isinstance(event.data, dict):
            member = self.members.add(event.data)
            if member is not None:
                self.dispatcher.emit(p.event("member_added"), member)
        elif event.event == p.internal("member_removed") and isinstance(
            event.data, dict
        ):
            member = self.members.remove(event.data)
            if member is not None:
                self.dispatcher.emit(p.event("member_removed"), member)
        else:
            super().handle(event)

    def disconnect(self) -> None:
        """Clear the member list and reset subscription flags."""
        self.members.reset()
        super().disconnect()

    async def history(
        self, params: Optional[PresenceHistoryParams] = None
    ) -> PresenceHistoryPage:
        """Fetch presence history via ``client.config.fetch_presence_history``."""
        return await self.client.config.fetch_presence_history(
            self.name, params or PresenceHistoryParams()
        )

    async def snapshot(
        self, params: Optional[PresenceSnapshotParams] = None
    ) -> PresenceSnapshot:
        """Fetch a presence snapshot via ``client.config.fetch_presence_snapshot``."""
        return await self.client.config.fetch_presence_snapshot(
            self.name, params or PresenceSnapshotParams()
        )
|
|
1497
|
+
|
|
1498
|
+
|
|
1499
|
+
class EncryptedChannel(PrivateChannel):
    """Private channel whose event payloads are encrypted with a NaCl SecretBox.

    The key comes from the base64 ``shared_secret`` in the auth response.
    Client events are not supported on these channels.
    """

    def __init__(self, name: str, client: "SockudoClient") -> None:
        super().__init__(name, client)
        self.shared_secret: Optional[bytes] = None

    async def authorize(self, socket_id: str) -> ChannelAuthorizationData:
        """Authorise, capture the decoded shared secret, and return the auth data.

        The returned object omits ``shared_secret``.

        Raises:
            AuthFailure: if the auth response has no ``shared_secret``.
        """
        response = await super().authorize(socket_id)
        secret_b64 = response.shared_secret
        if not secret_b64:
            raise AuthFailure(
                None,
                f"No shared_secret key in auth payload for encrypted channel: {self.name}",
            )
        self.shared_secret = base64.b64decode(secret_b64)
        return ChannelAuthorizationData(
            auth=response.auth, channel_data=response.channel_data
        )

    async def trigger(self, event: str, data: Any) -> bool:
        """Always raises: client events are unsupported on encrypted channels."""
        raise UnsupportedFeature(
            "Client events are not currently supported for encrypted channels"
        )

    def handle(self, event: SockudoEvent) -> None:
        """Decrypt and emit user events; pass internal/platform events to the base.

        A user event is silently dropped when there is no shared secret yet,
        when ``data`` is not a dict, or when it lacks string ``ciphertext`` and
        ``nonce`` fields. Both fields are base64.

        NOTE(review): a wrong key or a corrupted payload makes
        ``SecretBox.decrypt`` raise, and nothing here catches it. Confirm that
        callers expect handle() to propagate that error.
        """
        prefix = self.client.prefix
        if prefix.is_internal_event(event.event) or prefix.is_platform_event(
            event.event
        ):
            super().handle(event)
            return

        envelope = event.data
        if self.shared_secret is None or not isinstance(envelope, dict):
            return
        ciphertext_b64 = envelope.get("ciphertext")
        nonce_b64 = envelope.get("nonce")
        if not (isinstance(ciphertext_b64, str) and isinstance(nonce_b64, str)):
            return

        plaintext = SecretBox(self.shared_secret).decrypt(
            base64.b64decode(ciphertext_b64), base64.b64decode(nonce_b64)
        )
        self.dispatcher.emit(
            event.event,
            json.loads(plaintext.decode("utf-8")),
            EventMetadata(user_id=event.user_id),
        )
|
|
1539
|
+
|
|
1540
|
+
|
|
1541
|
+
class _ResolvedConfiguration:
|
|
1542
|
+
def __init__(self, options: SockudoOptions) -> None:
|
|
1543
|
+
self.cluster = options.cluster
|
|
1544
|
+
self.activity_timeout = options.activity_timeout
|
|
1545
|
+
self.use_tls = options.force_tls is not False
|
|
1546
|
+
self.ws_host = options.ws_host or f"ws-{options.cluster}.sockudo.io"
|
|
1547
|
+
self.ws_port = options.ws_port
|
|
1548
|
+
self.wss_port = options.wss_port
|
|
1549
|
+
self.ws_path = options.ws_path
|
|
1550
|
+
self.http_host = options.http_host or f"sockjs-{options.cluster}.sockudo.io"
|
|
1551
|
+
self.http_port = options.http_port
|
|
1552
|
+
self.https_port = options.https_port
|
|
1553
|
+
self.http_path = options.http_path
|
|
1554
|
+
self.pong_timeout = options.pong_timeout
|
|
1555
|
+
self.unavailable_timeout = options.unavailable_timeout
|
|
1556
|
+
self.enabled_transports = options.enabled_transports
|
|
1557
|
+
self.disabled_transports = options.disabled_transports
|
|
1558
|
+
self.channel_options = options.channel_authorization
|
|
1559
|
+
self.user_options = options.user_authentication
|
|
1560
|
+
self.presence_history = options.presence_history
|
|
1561
|
+
self._http_client = httpx.AsyncClient()
|
|
1562
|
+
|
|
1563
|
+
async def authorize_channel(
|
|
1564
|
+
self, request: ChannelAuthorizationRequest
|
|
1565
|
+
) -> ChannelAuthorizationData:
|
|
1566
|
+
if self.channel_options.custom_handler is not None:
|
|
1567
|
+
return await self.channel_options.custom_handler(request)
|
|
1568
|
+
params = dict(self.channel_options.params)
|
|
1569
|
+
if self.channel_options.params_provider:
|
|
1570
|
+
params.update(self.channel_options.params_provider())
|
|
1571
|
+
params["socket_id"] = request.socket_id
|
|
1572
|
+
params["channel_name"] = request.channel_name
|
|
1573
|
+
headers = dict(self.channel_options.headers)
|
|
1574
|
+
if self.channel_options.headers_provider:
|
|
1575
|
+
headers.update(self.channel_options.headers_provider())
|
|
1576
|
+
payload = await self._perform_auth_request(
|
|
1577
|
+
self.channel_options.endpoint, headers, params
|
|
1578
|
+
)
|
|
1579
|
+
auth = payload.get("auth")
|
|
1580
|
+
if not isinstance(auth, str):
|
|
1581
|
+
raise AuthFailure(200, "JSON returned from auth endpoint was invalid")
|
|
1582
|
+
return ChannelAuthorizationData(
|
|
1583
|
+
auth=auth,
|
|
1584
|
+
channel_data=payload.get("channel_data"),
|
|
1585
|
+
shared_secret=payload.get("shared_secret"),
|
|
1586
|
+
)
|
|
1587
|
+
|
|
1588
|
+
async def authenticate_user(
    self, request: UserAuthenticationRequest
) -> UserAuthenticationData:
    """Obtain user-authentication data for ``request``.

    A configured ``custom_handler`` short-circuits the HTTP round trip.
    Otherwise the configured params and headers (static values first,
    provider values layered on top) plus ``socket_id`` are POSTed to the
    user-authentication endpoint.

    Raises:
        AuthFailure: if the JSON reply lacks string ``auth`` and ``user_data``.
    """
    opts = self.user_options
    handler = opts.custom_handler
    if handler is not None:
        return await handler(request)

    form = dict(opts.params)
    if opts.params_provider:
        form.update(opts.params_provider())
    form["socket_id"] = request.socket_id

    extra_headers = dict(opts.headers)
    if opts.headers_provider:
        extra_headers.update(opts.headers_provider())

    body = await self._perform_auth_request(opts.endpoint, extra_headers, form)
    signature, user_data = body.get("auth"), body.get("user_data")
    both_strings = isinstance(signature, str) and isinstance(user_data, str)
    if not both_strings:
        raise AuthFailure(200, "JSON returned from auth endpoint was invalid")
    return UserAuthenticationData(auth=signature, user_data=user_data)
|
|
1608
|
+
|
|
1609
|
+
async def close(self) -> None:
    """Close the ``httpx.AsyncClient`` used for auth and presence requests."""
    http_client = self._http_client
    await http_client.aclose()
|
|
1611
|
+
|
|
1612
|
+
async def fetch_presence_history(
    self, channel_name: str, params: PresenceHistoryParams
) -> PresenceHistoryPage:
    """Fetch one page of presence history for ``channel_name``.

    The request goes through the configured ``presence_history`` endpoint
    with action ``"history"``. The decoded page receives a callback that
    re-runs this query with every original parameter kept and only
    ``cursor`` replaced.

    Raises:
        UnsupportedFeature: if no ``presence_history`` configuration is set.
    """
    history_config = self.presence_history
    if history_config is None:
        raise UnsupportedFeature(
            "presence_history.endpoint must be configured to use presence.history(). "
            "This endpoint should proxy requests to the Sockudo server REST API."
        )

    def next_page(cursor: str) -> Awaitable[PresenceHistoryPage]:
        # Same query; only the cursor advances.
        follow_up = PresenceHistoryParams(
            direction=params.direction,
            limit=params.limit,
            cursor=cursor,
            start_serial=params.start_serial,
            end_serial=params.end_serial,
            start_time_ms=params.start_time_ms,
            end_time_ms=params.end_time_ms,
            start=params.start,
            end=params.end,
        )
        return self.fetch_presence_history(channel_name, follow_up)

    raw = await self._perform_presence_history_request(
        history_config.endpoint,
        history_config.headers,
        history_config.headers_provider,
        channel_name,
        params.to_payload(),
        "history",
    )
    return self._decode_presence_history_page(raw, next_page)
|
|
1647
|
+
|
|
1648
|
+
async def fetch_presence_snapshot(
    self, channel_name: str, params: PresenceSnapshotParams
) -> PresenceSnapshot:
    """Fetch a presence snapshot for ``channel_name``.

    The request goes through the configured ``presence_history`` endpoint
    with action ``"snapshot"``.

    Raises:
        UnsupportedFeature: if no ``presence_history`` configuration is set.
    """
    history_config = self.presence_history
    if history_config is None:
        raise UnsupportedFeature(
            "presence_history.endpoint must be configured to use presence.snapshot(). "
            "This endpoint should proxy requests to the Sockudo server REST API."
        )
    raw = await self._perform_presence_history_request(
        history_config.endpoint,
        history_config.headers,
        history_config.headers_provider,
        channel_name,
        params.to_payload(),
        "snapshot",
    )
    snapshot = self._decode_presence_snapshot(raw)
    return snapshot
|
|
1667
|
+
|
|
1668
|
+
async def _perform_auth_request(
    self, endpoint: str, headers: Dict[str, str], params: Dict[str, AuthValue]
) -> Dict[str, Any]:
    """POST ``params`` form-encoded to an auth ``endpoint`` and return its JSON object.

    Boolean values are sent as lowercase ``true``/``false``. A
    ``Content-Type: application/x-www-form-urlencoded`` header is added
    unless the caller already supplied a Content-Type (in any letter case).
    httpx sets no Content-Type for a raw ``content=`` string, and without
    one typical form-parsing auth servers cannot read the body. The
    caller's ``headers`` dict is not modified.

    Raises:
        AuthFailure: on an HTTP status >= 400, or when the response body is
            not a JSON object (including bodies that are not JSON at all).
    """
    request_headers = dict(headers)
    if not any(name.lower() == "content-type" for name in request_headers):
        request_headers["Content-Type"] = "application/x-www-form-urlencoded"
    body = urllib.parse.urlencode(
        {
            key: str(value).lower() if isinstance(value, bool) else value
            for key, value in params.items()
        }
    )
    response = await self._http_client.post(
        endpoint,
        headers=request_headers,
        content=body,
    )
    if response.status_code >= 400:
        raise AuthFailure(
            response.status_code,
            f"Could not get auth info from endpoint, status: {response.status_code}",
        )
    try:
        payload = response.json()
    except ValueError:
        # json.JSONDecodeError / UnicodeDecodeError both subclass ValueError.
        payload = None
    if not isinstance(payload, dict):
        raise AuthFailure(
            response.status_code, "JSON returned from auth endpoint was invalid"
        )
    return payload
|
|
1692
|
+
|
|
1693
|
+
async def _perform_presence_history_request(
    self,
    endpoint: str,
    headers: Dict[str, str],
    headers_provider: Optional[PresenceHistoryHeadersProvider],
    channel_name: str,
    params: Dict[str, Any],
    action: str,
) -> Dict[str, Any]:
    """POST ``{"channel", "params", "action"}`` as JSON and return the JSON reply.

    Header precedence, lowest to highest: the JSON Content-Type default,
    the static ``headers``, then ``headers_provider()``. The default is
    only added when no Content-Type is present in any letter case, so a
    caller's ``content-type`` never produces two Content-Type headers.

    Raises:
        SockudoException: on an HTTP status >= 400, or when the reply is not
            a JSON object (including bodies that are not JSON at all).
    """
    merged_headers = dict(headers)
    if headers_provider:
        merged_headers.update(headers_provider())
    if not any(name.lower() == "content-type" for name in merged_headers):
        merged_headers["Content-Type"] = "application/json"
    response = await self._http_client.post(
        endpoint,
        headers=merged_headers,
        content=json.dumps(
            {
                "channel": channel_name,
                "params": params,
                "action": action,
            }
        ),
    )
    if response.status_code >= 400:
        raise SockudoException(
            f"Presence {action} request failed ({response.status_code}): "
            f"{response.text}"
        )
    try:
        payload = response.json()
    except ValueError:
        # json.JSONDecodeError / UnicodeDecodeError both subclass ValueError.
        payload = None
    if not isinstance(payload, dict):
        raise SockudoException(f"Presence {action} endpoint returned invalid JSON")
    return payload
|
|
1725
|
+
|
|
1726
|
+
def _decode_presence_history_page(
    self,
    payload: Dict[str, Any],
    fetch_next: Callable[[str], Awaitable[PresenceHistoryPage]],
) -> PresenceHistoryPage:
    """Build a ``PresenceHistoryPage`` from a presence-history JSON object.

    Entries in ``items`` that are not objects are skipped. A missing,
    ``null`` or non-array ``items`` is treated as empty. A missing or
    ``null`` ``direction`` or ``limit`` falls back to ``"oldest_first"``
    or ``0``. ``fetch_next`` is handed to the page as ``_fetch_next``,
    presumably used to load the page after ``next_cursor``.

    Raises:
        KeyError: if an item lacks a required field.
        ValueError / TypeError: if a field cannot be converted.
    """
    raw_items = payload.get("items")
    if not isinstance(raw_items, (list, tuple)):
        raw_items = []
    direction = payload.get("direction")
    limit = payload.get("limit")
    next_cursor = payload.get("next_cursor")
    return PresenceHistoryPage(
        items=[
            PresenceHistoryItem(
                stream_id=str(item["stream_id"]),
                serial=int(item["serial"]),
                published_at_ms=int(item["published_at_ms"]),
                event=str(item["event"]),
                cause=str(item["cause"]),
                user_id=str(item["user_id"]),
                connection_id=(
                    str(item["connection_id"])
                    if item.get("connection_id") is not None
                    else None
                ),
                dead_node_id=(
                    str(item["dead_node_id"])
                    if item.get("dead_node_id") is not None
                    else None
                ),
                payload_size_bytes=int(item["payload_size_bytes"]),
                presence_event=dict(item.get("presence_event") or {}),
            )
            for item in raw_items
            if isinstance(item, dict)
        ],
        direction=str(direction) if direction is not None else "oldest_first",
        limit=int(limit) if limit is not None else 0,
        has_more=bool(payload.get("has_more", False)),
        next_cursor=str(next_cursor) if next_cursor is not None else None,
        bounds=self._decode_presence_history_bounds(payload.get("bounds")),
        continuity=self._decode_presence_history_continuity(
            payload.get("continuity")
        ),
        _fetch_next=fetch_next,
    )
|
|
1770
|
+
|
|
1771
|
+
def _decode_presence_snapshot(self, payload: Dict[str, Any]) -> PresenceSnapshot:
    """Build a ``PresenceSnapshot`` from a presence-snapshot JSON object.

    Entries in ``members`` that are not objects are skipped. A missing,
    ``null`` or non-array ``members`` is treated as empty. A missing or
    ``null`` ``channel``, ``member_count`` or ``events_replayed`` falls
    back to ``""`` / ``0`` / ``0``. Snapshot serial and time stay ``None``
    when absent.

    Raises:
        KeyError: if a member lacks a required field.
        ValueError / TypeError: if a field cannot be converted.
    """
    raw_members = payload.get("members")
    if not isinstance(raw_members, (list, tuple)):
        raw_members = []
    channel = payload.get("channel")
    member_count = payload.get("member_count")
    events_replayed = payload.get("events_replayed")
    snapshot_serial = payload.get("snapshot_serial")
    snapshot_time_ms = payload.get("snapshot_time_ms")
    return PresenceSnapshot(
        channel=str(channel) if channel is not None else "",
        members=[
            PresenceSnapshotMember(
                user_id=str(member["user_id"]),
                last_event=str(member["last_event"]),
                last_event_serial=int(member["last_event_serial"]),
                last_event_at_ms=int(member["last_event_at_ms"]),
            )
            for member in raw_members
            if isinstance(member, dict)
        ],
        member_count=int(member_count) if member_count is not None else 0,
        events_replayed=(
            int(events_replayed) if events_replayed is not None else 0
        ),
        snapshot_serial=(
            int(snapshot_serial) if snapshot_serial is not None else None
        ),
        snapshot_time_ms=(
            int(snapshot_time_ms) if snapshot_time_ms is not None else None
        ),
        continuity=self._decode_presence_history_continuity(
            payload.get("continuity")
        ),
    )
|
|
1800
|
+
|
|
1801
|
+
def _decode_presence_history_bounds(self, payload: Any) -> PresenceHistoryBounds:
    """Decode a ``bounds`` object; anything that is not a dict counts as empty.

    Each serial or timestamp is converted with ``int()`` when present and
    not ``null``; otherwise it is ``None``.
    """
    source = payload if isinstance(payload, dict) else {}

    def optional_int(key: str) -> Optional[int]:
        raw = source.get(key)
        return None if raw is None else int(raw)

    return PresenceHistoryBounds(
        start_serial=optional_int("start_serial"),
        end_serial=optional_int("end_serial"),
        start_time_ms=optional_int("start_time_ms"),
        end_time_ms=optional_int("end_time_ms"),
    )
|
|
1826
|
+
|
|
1827
|
+
def _decode_presence_history_continuity(
    self, payload: Any
) -> PresenceHistoryContinuity:
    """Decode a ``continuity`` object; anything that is not a dict counts as empty.

    ``stream_id`` and the oldest/newest serials and timestamps are
    ``None`` when absent or ``null``. The retained counters default to
    ``0`` and the flags to ``False`` when the key is absent.
    """
    source = payload if isinstance(payload, dict) else {}

    def optional_str(key: str) -> Optional[str]:
        raw = source.get(key)
        return None if raw is None else str(raw)

    def optional_int(key: str) -> Optional[int]:
        raw = source.get(key)
        return None if raw is None else int(raw)

    return PresenceHistoryContinuity(
        stream_id=optional_str("stream_id"),
        oldest_available_serial=optional_int("oldest_available_serial"),
        newest_available_serial=optional_int("newest_available_serial"),
        oldest_available_published_at_ms=optional_int(
            "oldest_available_published_at_ms"
        ),
        newest_available_published_at_ms=optional_int(
            "newest_available_published_at_ms"
        ),
        retained_events=int(source.get("retained_events", 0)),
        retained_bytes=int(source.get("retained_bytes", 0)),
        degraded=bool(source.get("degraded", False)),
        complete=bool(source.get("complete", False)),
        truncated_by_retention=bool(source.get("truncated_by_retention", False)),
    )
|
|
1864
|
+
|
|
1865
|
+
|
|
1866
|
+
class SockudoClient:
    """Asynchronous client for a Sockudo (Pusher-protocol) server.

    Manages the WebSocket connection, the channel registry, the
    reconnect/activity/unavailable timers, optional message
    de-duplication, connection recovery (resuming channels from their last
    seen serial) and delta compression.
    """

    def __init__(self, key: str, options: SockudoOptions) -> None:
        """Validate ``key`` and ``options`` and initialise client state.

        Raises:
            InvalidAppKey: if ``key`` is empty.
            InvalidOptions: if ``options.cluster`` is empty.
        """
        if not key:
            raise InvalidAppKey(
                "You must pass your app key when you instantiate SockudoClient."
            )
        if not options.cluster:
            raise InvalidOptions("Options must provide a cluster.")
        self.key = key
        self.options = options
        self.prefix = ProtocolPrefix(options.protocol_version)
        self.config = _ResolvedConfiguration(options)
        self.dispatcher = EventDispatcher()
        self.channels: Dict[str, SockudoChannel] = {}
        self.socket = None
        self.connection_state = ConnectionState.INITIALIZED
        self.socket_id: Optional[str] = None
        self._receive_task: Optional[asyncio.Task[Any]] = None
        self._activity_task: Optional[asyncio.Task[Any]] = None
        self._retry_task: Optional[asyncio.Task[Any]] = None
        self._unavailable_task: Optional[asyncio.Task[Any]] = None
        self._manually_disconnected = False
        self._current_transport: Optional[SockudoTransport] = None
        self._attempted_fallback = False
        self._channel_positions: Dict[str, RecoveryPosition] = {}
        self._deduplicator = (
            MessageDeduplicator(options.message_deduplication_capacity)
            if options.message_deduplication
            else None
        )
        self._delta_manager = (
            DeltaCompressionManager(
                options.delta_compression, self.send_event, self.prefix
            )
            if options.delta_compression
            else None
        )
        self.user = self.UserFacade(self)
        self.watchlist = self.WatchlistFacade()

    def bind(
        self, event_name: str, callback: Callable[[Any, Optional[EventMetadata]], None]
    ) -> str:
        """Register ``callback`` for client-level ``event_name``; returns the binding id."""
        return self.dispatcher.bind(event_name, callback)

    def bind_global(self, callback: Callable[[str, Any], None]) -> str:
        """Register ``callback`` for every client-level event; returns the binding id."""
        return self.dispatcher.bind_global(callback)

    def channel(self, name: str) -> Optional[SockudoChannel]:
        """Return the registered channel called ``name``, or ``None``."""
        return self.channels.get(name)

    def subscribe(
        self, channel_name: str, options: Optional[SubscriptionOptions] = None
    ) -> SockudoChannel:
        """Get or create ``channel_name``, apply ``options`` and subscribe if possible.

        The channel type is chosen by name prefix (see ``_create_channel``).
        When ``options`` is given, its filter, delta, events and rewind
        settings overwrite the channel's current ones.
        """
        channel = self.channels.get(channel_name)
        if channel is None:
            channel = self._create_channel(channel_name)
            self.channels[channel_name] = channel
        if options is not None:
            channel.filter = options.filter
            channel.delta_settings = options.delta
            channel.events_filter = options.events
            channel.rewind = options.rewind
        channel.subscribe_if_possible()
        return channel

    async def unsubscribe(self, channel_name: str) -> None:
        """Unsubscribe from ``channel_name`` and forget its recovery and delta state.

        A subscription still in flight is only flagged as cancelled, and the
        channel stays registered. Otherwise the channel is removed; if it was
        subscribed, its ``unsubscribe()`` is awaited.
        """
        channel = self.channels.get(channel_name)
        if channel is None:
            return
        if channel.subscription_pending:
            channel.subscription_cancelled = True
        elif channel.is_subscribed:
            self.channels.pop(channel_name, None)
            await channel.unsubscribe()
        else:
            self.channels.pop(channel_name, None)
        self._channel_positions.pop(channel_name, None)
        if self._delta_manager is not None:
            self._delta_manager.clear_channel_state(channel_name)

    async def connect(self) -> None:
        """Open the WebSocket connection unless a socket is already open.

        Moves to ``failed`` when the enabled/disabled transport options leave
        no transport. Errors raised while opening the first socket
        propagate to the caller.
        """
        if self.socket is not None:
            return
        transports = self._transport_sequence()
        if not transports:
            self._update_state(ConnectionState.FAILED)
            return
        self._manually_disconnected = False
        self._attempted_fallback = False
        # disconnect() closes the shared HTTP client. Give a reconnecting
        # client a fresh one so channel and user authorization keep working.
        if self.config._http_client.is_closed:
            self.config._http_client = httpx.AsyncClient()
        self._update_state(ConnectionState.CONNECTING)
        await self._open_websocket(transports[0])
        self._set_unavailable_timer()

    async def disconnect(self) -> None:
        """Close the socket on purpose: no automatic reconnect.

        Cancels all timers, disconnects every channel, moves to
        ``disconnected`` and closes the HTTP client used for auth requests.
        A later ``connect()`` recreates that client.
        """
        self._manually_disconnected = True
        self._cancel_timers()
        if self.socket is not None:
            await self.socket.close()
            self.socket = None
        for channel in self.channels.values():
            channel.disconnect()
        self._update_state(ConnectionState.DISCONNECTED)
        await self.config.close()

    async def close(self) -> None:
        """Alias for ``disconnect()``."""
        await self.disconnect()

    async def signin(self) -> None:
        """Request user sign-in (see ``UserFacade.sign_in``)."""
        await self.user.sign_in()

    async def send_event(self, name: str, data: Any, channel: Optional[str]) -> bool:
        """Encode and send one event frame.

        Returns ``False`` without sending when no socket is open, otherwise
        ``True`` after the frame has been handed to the socket.
        """
        if self.socket is None:
            return False
        payload: Dict[str, Any] = {"event": name, "data": data}
        if channel is not None:
            payload["channel"] = channel
        encoded = ProtocolCodec.encode_envelope(payload, self.options.wire_format)
        await self.socket.send(encoded)
        return True

    def get_delta_stats(self) -> Optional[DeltaStats]:
        """Return delta-compression statistics, or ``None`` when it is not configured."""
        return self._delta_manager.get_stats() if self._delta_manager else None

    async def _open_websocket(self, transport: SockudoTransport) -> None:
        """Connect a socket over ``transport`` and start its receive loop."""
        self._current_transport = transport
        self.socket = await ws_connect(self._socket_url(transport))
        self._receive_task = asyncio.create_task(self._receive_loop())

    async def _receive_loop(self) -> None:
        """Pass every frame from the current socket to ``_handle_raw_message``.

        ``_handle_socket_closed`` runs for error closes (``ConnectionClosed``
        raised) and for clean closes. On a clean close (e.g. 1000/1001 from
        the server) websockets simply ends the iteration without raising. A
        clean close is ignored after a manual disconnect, or once the client
        has moved on to a different socket.
        """
        socket = self.socket
        if socket is None:
            return
        try:
            async for raw_message in socket:
                await self._handle_raw_message(raw_message)
        except ConnectionClosed as exc:
            # ``exc.code``/``exc.reason`` are deprecated in websockets; read
            # the received close frame (1006 / "" when none arrived).
            received = exc.rcvd
            await self._handle_socket_closed(
                received.code if received is not None else 1006,
                received.reason if received is not None else "",
            )
            return
        if self._manually_disconnected or self.socket is not socket:
            return
        close_code = socket.close_code
        await self._handle_socket_closed(
            close_code if close_code is not None else 1000,
            socket.close_reason or "",
        )

    async def _handle_raw_message(self, raw_message: Union[str, bytes]) -> None:
        """Decode one frame and route it.

        Covers de-duplication, recovery-position tracking, protocol events
        (handshake, error, ping, sign-in, resume failure, watchlist, delta
        compression) and delivery to channels and client listeners. Any
        ``Exception`` is emitted on ``error`` rather than ending the receive
        loop. Task cancellation is not intercepted.
        """
        try:
            event = ProtocolCodec.decode_event(raw_message, self.options.wire_format)
            if event.message_id and self._deduplicator:
                if self._deduplicator.is_duplicate(event.message_id):
                    return
                self._deduplicator.track(event.message_id)
            self._reset_activity_timer()
            if (
                self.options.connection_recovery
                and event.channel
                and event.serial is not None
            ):
                self._channel_positions[event.channel] = RecoveryPosition(
                    serial=event.serial,
                    stream_id=event.stream_id,
                    last_message_id=event.message_id,
                )
            event_name = event.event
            if event_name == self.prefix.event("connection_established"):
                payload = event.data if isinstance(event.data, dict) else {}
                self.socket_id = payload.get("socket_id")
                if not isinstance(self.socket_id, str):
                    raise SockudoException("Invalid handshake")
                # The handshake arrived, so the "unavailable" timeout armed when
                # the socket was opened must not fire for this connection.
                self._clear_unavailable_timer()
                self._update_state(
                    ConnectionState.CONNECTED, {"socket_id": self.socket_id}
                )
                for channel in self.channels.values():
                    channel.subscribe_if_possible()
                if self.options.connection_recovery and self._channel_positions:
                    await self.send_event(
                        self.prefix.event("resume"),
                        {
                            "channel_positions": {
                                channel_name: {
                                    key: value
                                    for key, value in {
                                        "serial": position.serial,
                                        "stream_id": position.stream_id,
                                        "last_message_id": position.last_message_id,
                                    }.items()
                                    if value is not None
                                }
                                for channel_name, position in self._channel_positions.items()
                            }
                        },
                        None,
                    )
                if (
                    self.options.delta_compression
                    and self.options.delta_compression.enabled
                ):
                    assert self._delta_manager is not None
                    await self._delta_manager.enable()
                await self.user.handle_connected()
            elif event_name == self.prefix.event("error"):
                self.dispatcher.emit("error", event.data)
            elif event_name == self.prefix.event("ping"):
                await self.send_event(self.prefix.event("pong"), {}, None)
            elif event_name == self.prefix.event("signin_success"):
                await self.user.handle_sign_in_success(event.data)
            elif event_name == self.prefix.event("resume_failed"):
                payload = event.data if isinstance(event.data, dict) else {}
                failed_channel_name = payload.get("channel")
                if isinstance(failed_channel_name, str):
                    self._channel_positions.pop(failed_channel_name, None)
                    if self._delta_manager is not None:
                        self._delta_manager.clear_channel_state(failed_channel_name)
                    failed_channel = self.channels.get(failed_channel_name)
                    if failed_channel is not None:
                        failed_channel.is_subscribed = False
                        failed_channel.subscription_pending = False
                        failed_channel.subscribe_if_possible()
                self.dispatcher.emit(event_name, event.data)
            elif event_name == self.prefix.internal("watchlist_events"):
                self.watchlist.handle(event.data)
            elif (
                event_name == self.prefix.event("delta_compression_enabled")
                and self._delta_manager
            ):
                self._delta_manager.handle_enabled(event.data)
                self.dispatcher.emit(event_name, event.data)
            elif (
                event_name == self.prefix.event("delta_cache_sync")
                and self._delta_manager
                and event.channel
            ):
                self._delta_manager.handle_cache_sync(event.channel, event.data)
            elif (
                event_name == self.prefix.event("delta")
                and self._delta_manager
                and event.channel
            ):
                reconstructed = await self._delta_manager.handle_delta_message(
                    event.channel, event.data
                )
                if reconstructed is not None:
                    channel = self.channels.get(event.channel)
                    if channel is not None:
                        channel.handle(reconstructed)
                    self.dispatcher.emit(reconstructed.event, reconstructed.data)
            else:
                if event.channel and event.channel in self.channels:
                    self.channels[event.channel].handle(event)
                if (
                    not self.prefix.is_platform_event(event_name)
                    and not self.prefix.is_internal_event(event_name)
                    and event.sequence is not None
                    and self._delta_manager is not None
                ):
                    self._delta_manager.handle_full_message(
                        event.channel,
                        self._strip_delta_metadata(event.raw_message),
                        event.sequence,
                        event.conflation_key,
                    )
                if not self.prefix.is_internal_event(event_name):
                    self.dispatcher.emit(
                        event_name, event.data, EventMetadata(user_id=event.user_id)
                    )
        except Exception as exc:
            self.dispatcher.emit("error", exc)

    async def _handle_socket_closed(self, code: int, reason: str) -> None:
        """React to a closed socket.

        Drops the socket, stops the activity and unavailable timers and
        disconnects every channel. Unless the disconnect was manual, a
        reconnect is scheduled in 1 s. A non-empty ``reason`` is emitted on
        ``error``.
        """
        self.socket = None
        self._cancel_activity_timer()
        self._clear_unavailable_timer()
        for channel in self.channels.values():
            channel.disconnect()
        if not self._manually_disconnected:
            await self._schedule_retry(1.0)
        if reason:
            self.dispatcher.emit("error", reason)

    async def _schedule_retry(self, after_seconds: float) -> None:
        """Schedule a reconnect attempt ``after_seconds`` from now.

        Replaces any pending attempt and does nothing after a manual
        disconnect. After plain ``ws`` fails, one ``wss`` fallback is tried
        when available. If an attempt cannot open a socket, the error is
        emitted on ``error`` and another attempt is scheduled with the same
        delay, so a temporarily unreachable server does not stop
        reconnection for good.
        """
        if self._manually_disconnected:
            return
        if self._retry_task:
            self._retry_task.cancel()

        async def _retry() -> None:
            await asyncio.sleep(after_seconds)
            self._update_state(ConnectionState.CONNECTING)
            transports = self._transport_sequence()
            if (
                self._current_transport is SockudoTransport.WS
                and not self._attempted_fallback
                and SockudoTransport.WSS in transports
            ):
                self._attempted_fallback = True
                target = SockudoTransport.WSS
            else:
                self._attempted_fallback = False
                target = transports[0] if transports else SockudoTransport.WSS
            try:
                await self._open_websocket(target)
            except Exception as exc:
                self.dispatcher.emit("error", exc)
                # Detach first so rescheduling does not cancel this very task.
                self._retry_task = None
                await self._schedule_retry(after_seconds)
                return
            self._set_unavailable_timer()

        self._retry_task = asyncio.create_task(_retry())

    def _socket_url(self, transport: SockudoTransport) -> str:
        """Build the ``ws``/``wss`` URL (host, port, app path, protocol query)."""
        scheme = "wss" if transport is SockudoTransport.WSS else "ws"
        host = self.config.ws_host
        port = (
            self.config.wss_port
            if transport is SockudoTransport.WSS
            else self.config.ws_port
        )
        path = f"{self.config.ws_path}/app/{self.key}"
        query = {
            "protocol": self.prefix.version,
            "client": "python",
            "version": "2.0.0",
            "flash": "false",
        }
        if self.options.protocol_version >= 2:
            query["format"] = self.options.wire_format.value
            query["echo_messages"] = "true" if self.options.echo_messages else "false"
        return urllib.parse.urlunsplit(
            (scheme, f"{host}:{port}", path, urllib.parse.urlencode(query), "")
        )

    def _transport_sequence(self) -> List[SockudoTransport]:
        """Ordered transports to try (TLS only when ``use_tls``), after enable/disable filters."""
        transports = (
            [SockudoTransport.WSS]
            if self.config.use_tls
            else [SockudoTransport.WS, SockudoTransport.WSS]
        )
        if self.config.enabled_transports is not None:
            transports = [
                transport
                for transport in transports
                if transport in self.config.enabled_transports
            ]
        if self.config.disabled_transports is not None:
            transports = [
                transport
                for transport in transports
                if transport not in self.config.disabled_transports
            ]
        return transports

    def _create_channel(self, name: str) -> SockudoChannel:
        """Pick the channel class from the name prefix (encrypted, presence, private, public)."""
        if name.startswith("private-encrypted-"):
            return EncryptedChannel(name, self)
        if name.startswith("presence-"):
            return PresenceChannel(name, self)
        if name.startswith("private-"):
            return PrivateChannel(name, self)
        return SockudoChannel(name, self)

    def _update_state(
        self, state: ConnectionState, metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Record ``state``, then emit ``state_change`` and the state's own event."""
        previous = self.connection_state
        self.connection_state = state
        self.dispatcher.emit(
            "state_change", {"previous": previous.value, "current": state.value}
        )
        self.dispatcher.emit(state.value, metadata)

    def _cancel_activity_timer(self) -> None:
        """Cancel the pending activity (ping) timer, if any."""
        if self._activity_task:
            self._activity_task.cancel()
            self._activity_task = None

    def _reset_activity_timer(self) -> None:
        """Restart the timer that sends a ping after ``activity_timeout`` seconds of silence."""
        self._cancel_activity_timer()

        async def _timer() -> None:
            await asyncio.sleep(self.config.activity_timeout)
            await self.send_event(self.prefix.event("ping"), {}, None)

        self._activity_task = asyncio.create_task(_timer())

    def _set_unavailable_timer(self) -> None:
        """Move to ``unavailable`` unless cleared within ``unavailable_timeout`` seconds."""
        self._clear_unavailable_timer()

        async def _timer() -> None:
            await asyncio.sleep(self.config.unavailable_timeout)
            self._update_state(ConnectionState.UNAVAILABLE)

        self._unavailable_task = asyncio.create_task(_timer())

    def _clear_unavailable_timer(self) -> None:
        """Cancel the pending unavailable timer, if any."""
        if self._unavailable_task:
            self._unavailable_task.cancel()
            self._unavailable_task = None

    def _cancel_timers(self) -> None:
        """Cancel the activity, unavailable and retry tasks."""
        self._cancel_activity_timer()
        self._clear_unavailable_timer()
        if self._retry_task:
            self._retry_task.cancel()
            self._retry_task = None

    @staticmethod
    def _strip_delta_metadata(raw_message: str) -> str:
        """Remove the ``__delta_seq`` and ``__conflation_key`` members from a JSON frame.

        Both the key and its value are removed (only when the member follows
        a comma), so the result stays valid JSON and can serve as the delta
        base. Quotes escaped inside string values are never matched.
        """
        import re

        without_seq = re.sub(r',\s*"__delta_seq"\s*:\s*-?\d+', "", raw_message)
        return re.sub(
            r',\s*"__conflation_key"\s*:\s*(?:"(?:[^"\\]|\\.)*"|null)',
            "",
            without_seq,
        )

    def reset_delta_stats(self) -> None:
        """Reset delta-compression statistics when delta compression is configured."""
        if self._delta_manager is not None:
            self._delta_manager.reset_stats()

    class UserFacade:
        """User sign-in state and the server-to-user channel of one client."""

        def __init__(self, client: "SockudoClient") -> None:
            self.client = client
            self.dispatcher = EventDispatcher()
            self.is_sign_in_requested = False
            self.user_data: Optional[Dict[str, Any]] = None
            self.server_channel: Optional[SockudoChannel] = None

        @property
        def user_id(self) -> Optional[str]:
            """The signed-in user's ``id`` when it is a string, otherwise ``None``."""
            if self.user_data is None:
                return None
            value = self.user_data.get("id")
            return value if isinstance(value, str) else None

        def bind(
            self,
            event_name: str,
            callback: Callable[[Any, Optional[EventMetadata]], None],
        ) -> str:
            """Register ``callback`` for user event ``event_name``; returns the binding id."""
            return self.dispatcher.bind(event_name, callback)

        async def sign_in(self) -> None:
            """Request sign-in now and after every later (re)connection."""
            self.is_sign_in_requested = True
            await self._attempt_sign_in()

        async def handle_connected(self) -> None:
            """Retry sign-in once the client is connected."""
            await self._attempt_sign_in()

        async def handle_sign_in_success(self, data: Any) -> None:
            """Store the signed-in user and subscribe to their server channel.

            Missing or invalid ``user_data`` clears the user state. If
            ``user_data`` is not valid JSON, the state is cleared before the
            decode error propagates.
            """
            payload = data if isinstance(data, dict) else {}
            user_data = payload.get("user_data")
            if not isinstance(user_data, str):
                self._cleanup()
                return
            try:
                parsed = json.loads(user_data)
            except ValueError:
                self._cleanup()
                raise
            if not isinstance(parsed, dict) or not isinstance(parsed.get("id"), str):
                self._cleanup()
                return
            self.user_data = parsed
            await self._subscribe_server_channel(parsed["id"])

        async def _attempt_sign_in(self) -> None:
            """Authenticate and send ``signin`` when requested, connected and a socket id exists.

            A failure clears the user state and is emitted on the client's
            ``error`` event. Cancellation is not intercepted.
            """
            if (
                not self.is_sign_in_requested
                or self.client.connection_state is not ConnectionState.CONNECTED
            ):
                return
            if not self.client.socket_id:
                return
            try:
                auth = await self.client.config.authenticate_user(
                    UserAuthenticationRequest(self.client.socket_id)
                )
                await self.client.send_event(
                    self.client.prefix.event("signin"),
                    {"auth": auth.auth, "user_data": auth.user_data},
                    None,
                )
            except Exception as exc:
                self._cleanup()
                self.client.dispatcher.emit("error", exc)

        async def _subscribe_server_channel(self, user_id: str) -> None:
            """Create ``#server-to-user-<id>`` and re-emit its non-protocol events."""
            channel = SockudoChannel(f"#server-to-user-{user_id}", self.client)
            channel.bind_global(
                lambda event_name, data: (
                    self.dispatcher.emit(event_name, data)
                    if not self.client.prefix.is_internal_event(event_name)
                    and not self.client.prefix.is_platform_event(event_name)
                    else None
                )
            )
            self.server_channel = channel
            channel.subscribe_if_possible()

        def _cleanup(self) -> None:
            """Forget the user and tear down the server-to-user channel."""
            self.user_data = None
            if self.server_channel is not None:
                self.server_channel.unbind()
                self.server_channel.disconnect()
                self.server_channel = None

    class WatchlistFacade:
        """Re-emits watchlist events pushed by the server, keyed by their ``name``."""

        def __init__(self) -> None:
            self.dispatcher = EventDispatcher()

        def bind(
            self,
            event_name: str,
            callback: Callable[[Any, Optional[EventMetadata]], None],
        ) -> str:
            """Register ``callback`` for watchlist event ``event_name``; returns the binding id."""
            return self.dispatcher.bind(event_name, callback)

        def handle(self, data: Any) -> None:
            """Emit each object in ``data["events"]`` that has a string ``name``."""
            payload = data if isinstance(data, dict) else {}
            events = payload.get("events", [])
            if not isinstance(events, list):
                return
            for event in events:
                if isinstance(event, dict) and isinstance(event.get("name"), str):
                    self.dispatcher.emit(event["name"], event)
|
|
2370
|
+
|
|
2371
|
+
|
|
2372
|
+
def _coerce_int(value: Any) -> Optional[int]:
    """Return ``value`` as an ``int`` when it is a usable number, else ``None``.

    * ``bool`` returns ``None``, even though ``bool`` subclasses ``int``.
    * ``int`` is returned unchanged.
    * Finite ``float`` values are truncated toward zero with ``int()``.
    * NaN and +/-infinity return ``None``. ``int()`` would raise
      ``ValueError`` / ``OverflowError`` on them.
    * Anything else (strings, ``None``, containers, ...) returns ``None``.
    """
    import math

    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        return int(value) if math.isfinite(value) else None
    return None
|
|
2380
|
+
|
|
2381
|
+
|
|
2382
|
+
def _write_varint(buffer: bytearray, value: int) -> None:
    """Append ``value`` to ``buffer`` as a protobuf base-128 varint.

    Seven bits are emitted per byte, least-significant group first. Every
    byte except the last has its high bit (0x80) set.

    Negative values are encoded as their 64-bit two's complement, which is
    always ten bytes. This matches how protobuf serialises negative
    ``int32``/``int64`` values, and avoids ``bytearray.append`` rejecting a
    negative byte.
    """
    if value < 0:
        value &= 0xFFFFFFFFFFFFFFFF
    while value >= 0x80:
        buffer.append((value & 0x7F) | 0x80)
        value >>= 7
    buffer.append(value)
|
|
2389
|
+
|
|
2390
|
+
|
|
2391
|
+
def _write_key(buffer: bytearray, field: int, wire_type: int) -> None:
    """Append a protobuf field tag: ``(field << 3) | wire_type`` as a varint."""
    # NOTE: the parameter name ``field`` shadows ``dataclasses.field`` inside
    # this function only. It is kept for keyword-argument compatibility.
    tag = (field << 3) | wire_type
    _write_varint(buffer, tag)
|
|
2393
|
+
|
|
2394
|
+
|
|
2395
|
+
def _write_string_field(buffer: bytearray, field: int, value: Any) -> None:
    """Append ``value`` as a length-delimited (wire type 2) UTF-8 string field.

    Values that are not ``str`` (including ``None``) are skipped and write
    nothing, so callers can pass optional values directly.
    """
    if isinstance(value, str):
        raw = value.encode("utf-8")
        _write_key(buffer, field, 2)
        _write_varint(buffer, len(raw))
        buffer.extend(raw)
|
|
2402
|
+
|
|
2403
|
+
|
|
2404
|
+
def _write_bytes_field(buffer: bytearray, field: int, payload: bytes) -> None:
    """Append ``payload`` as a length-delimited (wire type 2) field.

    The tag is written first, then the varint byte count, then the raw bytes.
    """
    length_delimited = 2
    _write_key(buffer, field, length_delimited)
    _write_varint(buffer, len(payload))
    buffer.extend(payload)
|
|
2408
|
+
|
|
2409
|
+
|
|
2410
|
+
def _write_uint_field(buffer: bytearray, field: int, value: Any) -> None:
    """Append ``value`` as a varint (wire type 0) field if it coerces to an int.

    ``value`` is passed through ``_coerce_int``. When that yields ``None``
    (e.g. bools or non-numeric values), nothing is written.
    """
    number = _coerce_int(value)
    if number is not None:
        _write_key(buffer, field, 0)
        _write_varint(buffer, number)
|
|
2416
|
+
|
|
2417
|
+
|
|
2418
|
+
def _write_optional_bool_field(
    buffer: bytearray, field: int, value: Optional[bool]
) -> None:
    """Append ``value`` as a varint bool field (1 or 0), or nothing if ``None``."""
    if value is not None:
        _write_key(buffer, field, 0)
        _write_varint(buffer, int(bool(value)))
|
|
2425
|
+
|
|
2426
|
+
|
|
2427
|
+
def _write_bool_field(buffer: bytearray, field: int, value: bool) -> None:
    """Append ``value`` as a varint bool field (1 or 0); written even when false."""
    _write_key(buffer, field, 0)
    _write_varint(buffer, int(bool(value)))
|
|
2430
|
+
|
|
2431
|
+
|
|
2432
|
+
def _write_double_field(buffer: bytearray, field: int, value: float) -> None:
    """Append ``value`` as a wire-type-1 field holding a little-endian double."""
    _write_key(buffer, field, 1)
    packed = struct.pack("<d", value)  # 8 bytes, IEEE-754 binary64
    buffer.extend(packed)
|
|
2435
|
+
|
|
2436
|
+
|
|
2437
|
+
def _read_varint(data: bytes, index: int) -> Tuple[int, int]:
    """Decode a base-128 varint from ``data`` starting at offset ``index``.

    Returns ``(value, next_index)``, where ``next_index`` is the offset just
    past the varint's final byte (the first byte whose high bit is clear).
    A varint that runs off the end of ``data`` raises ``IndexError``.
    """
    acc = 0
    bits = 0
    pos = index
    while True:
        current = data[pos]
        pos += 1
        acc |= (current & 0x7F) << bits
        if not current & 0x80:
            return acc, pos
        bits += 7
|
|
2447
|
+
|
|
2448
|
+
|
|
2449
|
+
def _read_length_delimited(data: bytes, index: int) -> Tuple[bytes, int]:
    """Read a varint length prefix at ``index`` and the payload that follows it.

    Returns ``(payload, next_index)``, where ``next_index`` is the offset
    just past the payload.

    Raises:
        SockudoException: if the declared length runs past the end of
            ``data`` (a truncated or corrupt message). Without this check,
            slicing would silently return a short payload and an index
            beyond the buffer.
        IndexError: propagated from ``_read_varint`` if the length prefix
            itself is truncated.
    """
    length, start = _read_varint(data, index)
    end = start + length
    if end > len(data):
        raise SockudoException(
            f"Truncated protobuf field: need {length} bytes at offset {start}, "
            f"only {len(data) - start} available"
        )
    return data[start:end], end
|
|
2452
|
+
|
|
2453
|
+
|
|
2454
|
+
def _skip_unknown(data: bytes, index: int, wire: int) -> int:
    """Return the offset just past an unknown field's value.

    ``index`` must point at the first byte of the value (after the tag).
    Supported wire types are 0 (varint), 1 (fixed 64-bit), 2
    (length-delimited) and 5 (fixed 32-bit).

    Raises:
        SockudoException: for any other wire type, or when a fixed-width or
            length-delimited value would extend past the end of ``data``.
        IndexError: propagated from ``_read_varint`` if a varint (the value
            itself or a length prefix) is truncated.
    """
    if wire == 0:
        _, index = _read_varint(data, index)
        return index
    if wire == 1:
        end = index + 8
    elif wire == 2:
        length, index = _read_varint(data, index)
        end = index + length
    elif wire == 5:
        end = index + 4
    else:
        raise SockudoException(f"Unsupported protobuf wire type: {wire}")
    if end > len(data):
        raise SockudoException(
            f"Truncated protobuf field (wire type {wire}) at offset {index}"
        )
    return end
|