pycupra 0.0.15__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,796 @@
1
from __future__ import annotations

"""Functions to receive firebase cloud messaging notifications """
"""Taken from https://github.com/sdb9696/firebase-messaging with small modifications"""

import asyncio
import contextlib
import json
import logging
import ssl
import struct
import time
import traceback
from base64 import urlsafe_b64decode
from contextlib import suppress as contextlib_suppress
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Awaitable, Callable

from aiohttp import ClientSession
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_private_key
from google.protobuf.json_format import MessageToJson
from google.protobuf.message import Message
from http_ece import decrypt as http_decrypt  # type: ignore[import-untyped]

from .const import (
    MCS_HOST,
    MCS_PORT,
    MCS_SELECTIVE_ACK_ID,
    MCS_VERSION,
)
from .fcmregister import FcmRegister, FcmRegisterConfig
from .proto.mcs_pb2 import (  # pylint: disable=no-name-in-module
    Close,
    DataMessageStanza,
    HeartbeatAck,
    HeartbeatPing,
    IqStanza,
    LoginRequest,
    LoginResponse,
    SelectiveAck,
    StreamErrorStanza,
)
45
+
46
_logger = logging.getLogger(__name__)

# Signature of the notification callback. FcmPushClient *awaits* the callback
# as ``callback(payload, persistent_id, callback_context)``, so it must be a
# coroutine function; a plain function returning None fails when awaited.
OnNotificationCallable = Callable[[dict[str, Any], str, Any], Awaitable[None]]

# Signature of the callback that receives updated credentials; FcmPushClient
# forwards it to FcmRegister.
CredentialsUpdatedCallable = Callable[[dict[str, Any]], None]

# MCS Message Types and Tags
# Maps each MCS message kind to its one-byte wire tag. Keys are the protobuf
# classes this module can decode; plain-string keys mark tags whose message
# classes are not imported here, and such messages are skipped on receipt.
MCS_MESSAGE_TAG: dict[Any, int] = {
    HeartbeatPing: 0,
    HeartbeatAck: 1,
    LoginRequest: 2,
    LoginResponse: 3,
    Close: 4,
    "MessageStanza": 5,
    "PresenceStanza": 6,
    IqStanza: 7,
    DataMessageStanza: 8,
    "BatchPresenceStanza": 9,
    StreamErrorStanza: 10,
    "HttpRequest": 11,
    "HttpResponse": 12,
    "BindAccountRequest": 13,
    "BindAccountResponse": 14,
    "TalkMetadata": 15,
}
70
+
71
+
72
class ErrorType(Enum):
    """Categories used to count sequential errors in FcmPushClient.

    Each category has its own counter. When one counter reaches
    ``FcmPushClientConfig.abort_on_sequential_error_count``, the client
    shuts down.
    """

    # Connecting, a server Close message, or a failed read on the socket.
    CONNECTION = 1
    # Reading from the stream.
    READ = 2
    # Sending the login request, or an error in the login response.
    LOGIN = 3
    # The user-supplied notification callback raised.
    NOTIFY = 4
77
+
78
+
79
class FcmPushClientRunState(Enum):
    """Lifecycle states of an FcmPushClient, in the order they normally occur.

    Member values are one-element tuples, as inherited from the upstream
    implementation. They are kept unchanged so that ``.value`` comparisons
    elsewhere keep working.
    """

    CREATED = (1,)  # constructed, start() not yet called
    STARTING_TASKS = (2,)  # start() is spawning the listen/monitor tasks
    STARTING_CONNECTION = (3,)  # opening the TLS connection to MCS
    STARTING_LOGIN = (4,)  # login request being sent
    STARTED = (5,)  # successful LoginResponse received
    RESETTING = (6,)  # tearing down and re-establishing the connection
    STOPPING = (7,)  # shutdown in progress
    STOPPED = (8,)  # stop() finished
88
+
89
+
90
@dataclass
class FcmPushClientConfig:  # pylint:disable=too-many-instance-attributes
    """Tunable settings for :class:`FcmPushClient`.

    All intervals and timeouts are expressed in seconds.
    """

    #: Heartbeat interval requested from the server. It is sent in the login
    #: request as ``heartbeat_stat.interval_ms``; a falsy value skips the
    #: request. The upstream default was 10.
    server_heartbeat_interval: int | None = 20

    #: Idle time after which the client sends its own heartbeat ping.
    #: The upstream default was 20.
    client_heartbeat_interval: int | None = 30

    #: When True, a selective acknowledgement is sent for each data message
    #: received. When False, no acknowledgements are sent at all.
    send_selective_acknowledgements: bool = True

    #: Number of connection attempts before giving up.
    connection_retry_count: int = 5

    #: Base back-off before retrying a failed connection. It is multiplied by
    #: the square of the attempt number.
    start_seconds_before_retry_connect: float = 3

    #: Minimum time between the last login and a reconnect after an error or
    #: disconnection.
    reset_interval: float = 3

    #: How long to wait for any server traffic after a client heartbeat
    #: before resetting the connection.
    heartbeat_ack_timeout: float = 5

    #: Number of sequential errors of the same type that aborts the client.
    #: None disables aborting.
    abort_on_sequential_error_count: int | None = 3

    #: Period of the monitor task. The task checks heartbeats and stale
    #: connections.
    monitor_interval: float = 1

    #: Number of times each distinct warning message is logged before that
    #: warning goes silent.
    log_warn_limit: int | None = 5

    #: When True, logs full message contents, including tokens.
    log_debug_verbose: bool = False
132
+
133
+
134
class FcmPushClient:  # pylint:disable=too-many-instance-attributes
    """Client that connects to Firebase Cloud Messaging and receives messages.

    Call :meth:`checkin_or_register` to obtain credentials, :meth:`start` to
    spawn the listen and monitor tasks, and :meth:`stop` to cancel them.

    :param callback: coroutine function awaited for every decrypted data
        message as ``callback(payload, persistent_id, callback_context)``.
    :param fcm_config: registration configuration. It is passed to
        :class:`FcmRegister`, and its ``chrome_version`` is sent at login.
    :param credentials: credentials object returned by register()
    :param credentials_updated_callback: callback when new credentials are
        created to allow client to store them
    :param callback_context: opaque value passed back to ``callback``.
    :param received_persistent_ids: any persistent id's you already received.
    :param config: configuration class of :class:`FcmPushClientConfig`
    :param http_client_session: optional aiohttp session forwarded to
        :class:`FcmRegister`.
    """

    def __init__(
        self,
        callback: OnNotificationCallable,
        fcm_config: FcmRegisterConfig,
        credentials: dict | None = None,
        credentials_updated_callback: CredentialsUpdatedCallable | None = None,
        *,
        callback_context: object | None = None,
        received_persistent_ids: list[str] | None = None,
        config: FcmPushClientConfig | None = None,
        http_client_session: ClientSession | None = None,
    ):
        """Initializes the receiver. No network activity happens here."""
        self.callback = callback
        self.callback_context = callback_context
        self.fcm_config = fcm_config
        self.credentials = credentials
        self.credentials_updated_callback = credentials_updated_callback
        self.persistent_ids = received_persistent_ids if received_persistent_ids else []
        self.config = config if config else FcmPushClientConfig()
        # NOTE: this sets the level of the module-level logger, so it affects
        # every client in the process, not just this instance.
        if self.config.log_debug_verbose:
            _logger.setLevel(logging.DEBUG)
        self._http_client_session = http_client_session

        self.reader: asyncio.StreamReader | None = None
        self.writer: asyncio.StreamWriter | None = None
        self.do_listen = False
        self.sequential_error_counters: dict[ErrorType, int] = {}
        self.log_warn_counters: dict[str, int] = {}

        # reset variables (re-initialised on every login)
        self.input_stream_id = 0
        self.last_input_stream_id_reported = -1
        self.first_message = True
        self.last_login_time: float | None = None
        self.last_message_time: float | None = None

        self.run_state: FcmPushClientRunState = FcmPushClientRunState.CREATED
        self.tasks: list[asyncio.Task] = []

        # Both locks are created in start().
        self.reset_lock: asyncio.Lock | None = None
        self.stopping_lock: asyncio.Lock | None = None

    def _msg_str(self, msg: Message) -> str:
        """Describe ``msg`` for logging: class name, plus JSON body when verbose."""
        if self.config.log_debug_verbose:
            return type(msg).__name__ + "\n" + MessageToJson(msg, indent=4)
        return type(msg).__name__

    def _log_verbose(self, msg: str, *args: object) -> None:
        """Debug-log only when ``config.log_debug_verbose`` is enabled."""
        if self.config.log_debug_verbose:
            _logger.debug(msg, *args)

    def _log_warn_with_limit(self, msg: str, *args: object) -> None:
        """Log a warning at most ``config.log_warn_limit`` times per format string.

        A falsy limit (None or 0) means these warnings are never logged.
        """
        if msg not in self.log_warn_counters:
            self.log_warn_counters[msg] = 0
        if (
            self.config.log_warn_limit
            and self.config.log_warn_limit > self.log_warn_counters[msg]
        ):
            self.log_warn_counters[msg] += 1
            _logger.warning(msg, *args)

    async def _do_writer_close(self) -> None:
        """Close the current stream writer, if any, and ignore close errors."""
        writer = self.writer
        # Clear the attribute first so no other coroutine writes to a closing stream.
        self.writer = None
        if writer:
            writer.close()
            with contextlib.suppress(Exception):
                await writer.wait_closed()

    async def _reset(self) -> None:
        """Drop the current connection, reconnect and log in again.

        Does nothing while another reset or a stop is in progress, or after
        listening has been disabled. Terminates the client if reconnecting
        fails.
        """
        if (
            (self.reset_lock and self.reset_lock.locked())
            or (self.stopping_lock and self.stopping_lock.locked())
            or not self.do_listen
        ):
            return

        async with self.reset_lock:  # type: ignore[union-attr]
            _logger.debug("Resetting connection")

            self.run_state = FcmPushClientRunState.RESETTING

            await self._do_writer_close()

            # Rate-limit reconnects: wait until reset_interval has passed
            # since the last login attempt.
            if self.last_login_time is not None:
                time_since_last_login = time.time() - self.last_login_time
                if time_since_last_login < self.config.reset_interval:
                    _logger.debug("%ss since last reset attempt.", time_since_last_login)
                    await asyncio.sleep(
                        self.config.reset_interval - time_since_last_login
                    )

            _logger.debug("Reestablishing connection")
            if not await self._connect_with_retry():
                _logger.error(
                    "Unable to connect to MCS endpoint "
                    + "after %s tries, shutting down",
                    self.config.connection_retry_count,
                )
                self._terminate()
                return
            _logger.debug("Re-connected to ssl socket")

            await self._login()

    # protobuf variable length integers are encoded in base 128
    # each byte contains 7 bits of the integer and the msb is set if there's
    # more. pretty simple to implement
    async def _read_varint32(self) -> int:
        """Read one base-128 varint from ``self.reader``, least significant group first."""
        res = 0
        shift = 0
        while True:
            r = await self.reader.readexactly(1)  # type: ignore[union-attr]
            (b,) = struct.unpack("B", r)
            res |= (b & 0x7F) << shift
            if (b & 0x80) == 0:
                break
            shift += 7
        return res

    @staticmethod
    def _encode_varint32(x: int) -> bytes:
        """Encode ``x`` as a base-128 varint. ``x`` must be non-negative.

        A negative value would never reach zero under ``>>=`` and would loop
        forever. Callers only pass payload lengths.
        """
        if x == 0:
            return bytes(bytearray([0]))

        res = bytearray([])
        while x != 0:
            b = x & 0x7F
            x >>= 7
            if x != 0:
                b |= 0x80  # continuation bit: more bytes follow
            res.append(b)
        return bytes(res)

    @staticmethod
    def _make_packet(msg: Message, include_version: bool) -> bytes:
        """Frame ``msg`` as ``[version] tag varint(len) payload``."""
        tag = MCS_MESSAGE_TAG[type(msg)]

        header = bytearray([MCS_VERSION, tag]) if include_version else bytearray([tag])

        payload = msg.SerializeToString()
        buf = bytes(header) + FcmPushClient._encode_varint32(len(payload)) + payload
        return buf

    async def _send_msg(self, msg: Message) -> None:
        """Serialize, frame and write ``msg``, then wait for the writer to drain.

        The version byte is included while ``first_message`` is still set.
        """
        self._log_verbose("Sending packet to server: %s", self._msg_str(msg))

        buf = FcmPushClient._make_packet(msg, self.first_message)
        self.writer.write(buf)  # type: ignore[union-attr]
        await self.writer.drain()  # type: ignore[union-attr]

    async def _receive_msg(self) -> Message | None:
        """Read one framed MCS message from the stream and decode it.

        Frame layout: a version byte (first message after login only), a tag
        byte, a varint32 payload length, then the payload.

        :return: the decoded protobuf message, or None if the tag is unknown
            or has no configured message class. The payload is still consumed,
            so the stream stays aligned.
        :raises RuntimeError: if the server announces an unsupported
            protocol version.
        :raises asyncio.IncompleteReadError: if the stream ends mid-frame.
        """
        if self.first_message:
            r = await self.reader.readexactly(2)  # type: ignore[union-attr]
            version, tag = struct.unpack("BB", r)
            if version < MCS_VERSION and version != 38:
                raise RuntimeError(f"protocol version {version} unsupported")
            self.first_message = False
        else:
            r = await self.reader.readexactly(1)  # type: ignore[union-attr]
            (tag,) = struct.unpack("B", r)
        size = await self._read_varint32()

        self._log_verbose(
            "Received message with tag %s and size %s",
            tag,
            size,
        )

        if size < 0:
            self._log_warn_with_limit("Unexpected message size %s", size)
            return None

        buf = await self.reader.readexactly(size)  # type: ignore[union-attr]

        # Use a default instead of letting next() raise StopIteration. Inside a
        # coroutine that would surface as RuntimeError and terminate the client.
        msg_class = next((c for c, t in MCS_MESSAGE_TAG.items() if t == tag), None)
        if msg_class is None:
            self._log_warn_with_limit("Unexpected message tag %s", tag)
            return None
        if isinstance(msg_class, str):
            self._log_warn_with_limit("Unconfigured message class %s", msg_class)
            return None

        payload = msg_class()  # type: ignore[operator]
        payload.ParseFromString(buf)
        self._log_verbose("Received payload: %s", self._msg_str(payload))

        return payload

    async def _login(self) -> None:
        """Send a LoginRequest on the current connection.

        Per-connection stream state is reset first. Failures are logged and
        counted as LOGIN errors, and trigger a reset unless the error limit
        has been reached.
        """
        self.run_state = FcmPushClientRunState.STARTING_LOGIN

        now = time.time()
        self.input_stream_id = 0
        self.last_input_stream_id_reported = -1
        self.first_message = True
        self.last_login_time = now

        try:
            android_id = self.credentials["gcm"]["android_id"]  # type: ignore[index]
            req = LoginRequest()
            req.adaptive_heartbeat = False
            req.auth_service = LoginRequest.ANDROID_ID  # 2
            req.auth_token = self.credentials["gcm"]["security_token"]  # type: ignore[index]
            req.id = self.fcm_config.chrome_version
            req.domain = "mcs.android.com"
            req.device_id = f"android-{int(android_id):x}"
            req.network_type = 1
            req.resource = android_id
            req.user = android_id
            req.use_rmq2 = True
            req.setting.add(name="new_vc", value="1")
            # Tell the server which messages were already received.
            req.received_persistent_id.extend(self.persistent_ids)
            if (
                self.config.server_heartbeat_interval
                and self.config.server_heartbeat_interval > 0
            ):
                req.heartbeat_stat.ip = ""
                req.heartbeat_stat.timeout = True
                req.heartbeat_stat.interval_ms = (
                    1000 * self.config.server_heartbeat_interval
                )

            await self._send_msg(req)
            _logger.debug("Sent login request")
        except Exception as ex:
            _logger.error("Received an exception logging in: %s", ex)
            if self._try_increment_error_count(ErrorType.LOGIN):
                await self._reset()

    @staticmethod
    def _decrypt_raw_data(
        credentials: dict[str, dict[str, str]],
        crypto_key: str,
        salt_str: str,
        raw_data: bytes,
    ) -> bytes:
        """Decrypt an ``aesgcm`` web-push payload.

        :param credentials: must contain ``keys.private`` (base64url DER
            private key) and ``keys.secret`` (base64url auth secret).
        :param crypto_key: base64url sender public key (the ``dh=`` value).
        :param salt_str: base64url salt (the ``salt=`` value).
        :param raw_data: encrypted message body.
        :return: the decrypted payload bytes.
        """
        # NOTE(review): the first parameter was named ``crypto_key_str`` upstream;
        # callers here pass it positionally.
        crypto_key_bytes = urlsafe_b64decode(crypto_key.encode("ascii"))
        salt = urlsafe_b64decode(salt_str.encode("ascii"))
        der_data_str = credentials["keys"]["private"]
        # The '=' suffix supplies base64 padding the stored strings presumably lack.
        der_data = urlsafe_b64decode(der_data_str.encode("ascii") + b"========")
        secret_str = credentials["keys"]["secret"]
        secret = urlsafe_b64decode(secret_str.encode("ascii") + b"========")
        privkey = load_der_private_key(
            der_data, password=None, backend=default_backend()
        )
        decrypted = http_decrypt(
            raw_data,
            salt=salt,
            private_key=privkey,
            dh=crypto_key_bytes,
            version="aesgcm",
            auth_secret=secret,
        )
        return decrypted

    def _app_data_by_key(
        self, p: DataMessageStanza, key: str, do_not_raise: bool = False
    ) -> str:
        """Return the value of the first ``app_data`` entry named ``key``.

        :return: the value, or "" when missing and ``do_not_raise`` is set.
        :raises RuntimeError: when missing and ``do_not_raise`` is False.
        """
        for x in p.app_data:
            if x.key == key:
                return x.value

        if do_not_raise:
            return ""
        raise RuntimeError(f"couldn't find in app_data {key}")

    async def _handle_data_message(
        self,
        msg: DataMessageStanza,
    ) -> None:
        """Decrypt a data message and pass it to the notification callback.

        The callback receives the parsed JSON if the payload decodes to a
        truthy JSON value, otherwise the raw decrypted bytes. Exceptions from
        the callback are logged and counted as NOTIFY errors.

        :raises RuntimeError: if the ``crypto-key``, ``encryption`` or
            ``subtype`` app_data entries are missing.
        """
        _logger.debug(
            "Received data message Stream ID: %s, Last: %s, Status: %s",
            msg.stream_id,
            msg.last_stream_id_received,
            msg.status,
        )

        if (
            self._app_data_by_key(msg, "message_type", do_not_raise=True)
            == "deleted_messages"
        ):
            # The deleted_messages message does not contain data.
            return
        crypto_key = self._app_data_by_key(msg, "crypto-key")[3:]  # strip dh=
        salt = self._app_data_by_key(msg, "encryption")[5:]  # strip salt=
        subtype = self._app_data_by_key(msg, "subtype")
        # Without credentials nothing can be checked or decrypted. Bail out
        # before dereferencing them.
        if not self.credentials:
            return
        if subtype != self.credentials["gcm"]["app_id"]:
            self._log_warn_with_limit(
                "Subtype %s in data message does not match "
                + "app id client was registered with %s",
                subtype,
                self.credentials["gcm"]["app_id"],
            )
        decrypted = self._decrypt_raw_data(
            self.credentials, crypto_key, salt, msg.raw_data
        )
        # Pre-bind so a non-JSON / non-UTF-8 payload (whose exception is
        # suppressed below) does not leave the name unbound.
        decrypted_json = None
        with contextlib_suppress(json.JSONDecodeError, ValueError):
            decrypted_json = json.loads(decrypted.decode("utf-8"))

        ret_val = decrypted_json if decrypted_json else decrypted
        self._log_verbose(
            "Decrypted data for message %s is: %s", msg.persistent_id, ret_val
        )
        try:
            await self.callback(ret_val, msg.persistent_id, self.callback_context)
            self._reset_error_count(ErrorType.NOTIFY)
        except Exception:
            _logger.exception("Unexpected exception calling notification callback\n")
            self._try_increment_error_count(ErrorType.NOTIFY)

    def _new_input_stream_id_available(self) -> bool:
        """True if messages arrived since the stream id was last reported."""
        return self.last_input_stream_id_reported != self.input_stream_id

    def _get_input_stream_id(self) -> int:
        """Return the current input stream id and mark it as reported."""
        self.last_input_stream_id_reported = self.input_stream_id
        return self.input_stream_id

    async def _handle_ping(self, p: HeartbeatPing) -> None:
        """Answer a server HeartbeatPing with a HeartbeatAck.

        The ack carries the latest received stream id if it has changed.
        Heartbeat traffic is not logged; upstream's debug log call was
        disabled in this fork.
        """
        req = HeartbeatAck()

        if self._new_input_stream_id_available():
            req.last_stream_id_received = self._get_input_stream_id()

        await self._send_msg(req)

    async def _handle_iq(self, p: IqStanza) -> None:
        """Validate an IqStanza extension. Unexpected ones are only logged.

        NOTE(review): _handle_message currently ignores IqStanza messages, so
        nothing in this class calls this method.
        """
        if not p.extension:
            self._log_warn_with_limit(
                "Unexpected IqStanza id received with no extension: %s", str(p)
            )
            return
        if p.extension.id not in (12, 13):
            self._log_warn_with_limit(
                "Unexpected extension id received: %s", p.extension.id
            )
            return

    async def _send_selective_ack(self, persistent_id: str) -> None:
        """Acknowledge a single received message by its persistent id."""
        iqs = IqStanza()
        iqs.type = IqStanza.IqType.SET
        iqs.id = ""
        iqs.extension.id = MCS_SELECTIVE_ACK_ID
        sa = SelectiveAck()
        sa.id.extend([persistent_id])
        iqs.extension.data = sa.SerializeToString()
        _logger.debug("Sending selective ack for message id %s", persistent_id)
        await self._send_msg(iqs)

    async def _send_heartbeat(self) -> None:
        """Send a client HeartbeatPing, reporting the stream id if it changed."""
        req = HeartbeatPing()

        if self._new_input_stream_id_available():
            req.last_stream_id_received = self._get_input_stream_id()

        await self._send_msg(req)

    def _terminate(self) -> None:
        """Stop listening and cancel every task except the calling one."""
        self.run_state = FcmPushClientRunState.STOPPING

        self.do_listen = False
        current_task = asyncio.current_task()
        for task in self.tasks:
            # Cancelling yourself would abort the caller mid-cleanup.
            if current_task != task and not task.done():
                task.cancel()

    async def _do_monitor(self) -> None:
        """Periodically check connection liveness while listening.

        When ``client_heartbeat_interval`` is set, a heartbeat is sent after
        that much silence, and the connection is reset if no traffic arrives
        within ``heartbeat_ack_timeout``. Otherwise the connection is reset
        when the server misses its own heartbeat interval by more than 2 s.
        """
        while self.do_listen:
            await asyncio.sleep(self.config.monitor_interval)

            if self.run_state == FcmPushClientRunState.STARTED:
                # if server_heartbeat_interval is set and less than
                # client_heartbeat_interval then the last_message_time
                # will be within the client window if connected
                if self.config.client_heartbeat_interval:
                    now = time.time()
                    if (
                        self.last_message_time + self.config.client_heartbeat_interval  # type: ignore[operator]
                        < now
                    ):
                        await self._send_heartbeat()
                        await asyncio.sleep(self.config.heartbeat_ack_timeout)
                        now = time.time()
                        if (  # Check state hasn't changed during sleep
                            self.last_message_time  # type: ignore[operator]
                            + self.config.client_heartbeat_interval
                            < now
                            and self.do_listen
                            and self.run_state == FcmPushClientRunState.STARTED
                        ):
                            await self._reset()
                elif self.config.server_heartbeat_interval:
                    now = time.time()
                    if (  # We give the server 2 extra seconds
                        self.last_message_time + self.config.server_heartbeat_interval  # type: ignore[operator]
                        < now - 2
                    ):
                        await self._reset()

    def _reset_error_count(self, error_type: ErrorType) -> None:
        """Clear the sequential error counter for ``error_type``."""
        self.sequential_error_counters[error_type] = 0

    def _try_increment_error_count(self, error_type: ErrorType) -> bool:
        """Count one more sequential error of ``error_type``.

        :return: False if the abort threshold was reached (the client has
            been terminated), True if the caller may keep going or retry.
        """
        if error_type not in self.sequential_error_counters:
            self.sequential_error_counters[error_type] = 0

        self.sequential_error_counters[error_type] += 1

        if (
            self.config.abort_on_sequential_error_count
            and self.sequential_error_counters[error_type]
            >= self.config.abort_on_sequential_error_count
        ):
            _logger.error(
                "Shutting down push receiver due to %s sequential errors of type %s",
                self.sequential_error_counters[error_type],
                error_type,
            )
            self._terminate()
            return False
        return True

    async def _handle_message(self, msg: Message) -> None:
        """Dispatch one decoded server message and update liveness bookkeeping."""
        self.last_message_time = time.time()
        self.input_stream_id += 1

        if isinstance(msg, Close):
            self._log_warn_with_limit("Server sent Close message, resetting")
            if self._try_increment_error_count(ErrorType.CONNECTION):
                await self._reset()
            return

        if isinstance(msg, LoginResponse):
            # An empty ErrorInfo message stringifies to "".
            if str(msg.error):
                _logger.error("Received login error response: %s", msg)
                if self._try_increment_error_count(ErrorType.LOGIN):
                    await self._reset()
            else:
                _logger.info("Successfully logged in to MCS endpoint")
                self._reset_error_count(ErrorType.LOGIN)
                self.run_state = FcmPushClientRunState.STARTED
                # The ids were reported in the login request; start afresh.
                self.persistent_ids = []
            return

        if isinstance(msg, DataMessageStanza):
            await self._handle_data_message(msg)
            self.persistent_ids.append(msg.persistent_id)
            if self.config.send_selective_acknowledgements:
                await self._send_selective_ack(msg.persistent_id)
        elif isinstance(msg, HeartbeatPing):
            await self._handle_ping(msg)
        elif isinstance(msg, (HeartbeatAck, IqStanza)):
            # Nothing to do: receiving it already refreshed last_message_time.
            pass
        else:
            self._log_warn_with_limit("Unexpected message type %s.", type(msg).__name__)
        # Reset error count if a read has been successful
        self._reset_error_count(ErrorType.READ)
        self._reset_error_count(ErrorType.CONNECTION)

    @staticmethod
    async def _open_connection(
        host: str, port: int, ssl_context: ssl.SSLContext
    ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """Open a TLS stream connection to ``host:port``."""
        return await asyncio.open_connection(host=host, port=port, ssl=ssl_context)

    async def _connect(self) -> bool:
        """Open the MCS connection once.

        :return: True on success, False on OSError (which is logged).
        """
        try:
            loop = asyncio.get_running_loop()
            # create_default_context() blocks the event loop
            ssl_context = await loop.run_in_executor(None, ssl.create_default_context)
            self.reader, self.writer = await self._open_connection(
                host=MCS_HOST, port=MCS_PORT, ssl_context=ssl_context
            )
            _logger.debug("Connected to MCS endpoint (%s,%s)", MCS_HOST, MCS_PORT)
            return True
        except OSError as oex:
            _logger.error(
                "Could not connect to MCS endpoint (%s,%s): %s",
                MCS_HOST,
                MCS_PORT,
                oex,
            )
            return False

    async def _connect_with_retry(self) -> bool:
        """Try to connect up to ``config.connection_retry_count`` times.

        Between failed attempts it sleeps
        ``start_seconds_before_retry_connect * attempt**2`` seconds. It gives
        up early once ``do_listen`` is cleared, and does not sleep when no
        further attempt will follow.

        :return: True if a connection was established.
        """
        self.run_state = FcmPushClientRunState.STARTING_CONNECTION

        trycount = 0
        connected = False
        while (
            trycount < self.config.connection_retry_count
            and not connected
            and self.do_listen
        ):
            trycount += 1
            connected = await self._connect()
            if connected:
                break
            if trycount >= self.config.connection_retry_count or not self.do_listen:
                # Last attempt (or stopped): sleeping would only delay the failure.
                break
            sleep_time = (
                self.config.start_seconds_before_retry_connect * trycount * trycount
            )
            _logger.info(
                "Could not connect to MCS Endpoint on "
                + "try %s, sleeping for %s seconds",
                trycount,
                sleep_time,
            )
            await asyncio.sleep(sleep_time)
        if not connected:
            _logger.error(
                "Unable to connect to MCS endpoint after %s tries, aborting", trycount
            )
        return connected

    async def _listen(self) -> None:
        """listens for push notifications.

        Connects, logs in, then reads and dispatches messages until
        ``do_listen`` is cleared. Read errors that are expected while a reset
        is in progress are only logged. Other OSError/EOFError count as
        CONNECTION errors and trigger a reset. Any other exception terminates
        the client. The writer is always closed on exit.
        """
        if not await self._connect_with_retry():
            return

        try:
            await self._login()

            while self.do_listen:
                try:
                    if self.run_state == FcmPushClientRunState.RESETTING:
                        await asyncio.sleep(1)
                    elif msg := await self._receive_msg():
                        await self._handle_message(msg)

                except (OSError, EOFError) as osex:
                    if (
                        isinstance(
                            osex,
                            (
                                ConnectionResetError,
                                TimeoutError,
                                asyncio.IncompleteReadError,
                                ssl.SSLError,
                            ),
                        )
                        and self.run_state == FcmPushClientRunState.RESETTING
                    ):
                        if (
                            isinstance(osex, ssl.SSLError)  # pylint: disable=no-member
                            and osex.reason != "APPLICATION_DATA_AFTER_CLOSE_NOTIFY"
                        ):
                            self._log_warn_with_limit(
                                "Unexpected SSLError reason during reset of %s",
                                osex.reason,
                            )
                        else:
                            self._log_verbose(
                                "Expected read error during reset: %s",
                                type(osex).__name__,
                            )
                    else:
                        _logger.exception("Unexpected exception during read\n")
                        if self._try_increment_error_count(ErrorType.CONNECTION):
                            await self._reset()
        except Exception as ex:
            _logger.error(
                "Unknown error: %s, shutting down FcmPushClient.\n%s",
                ex,
                traceback.format_exc(),
            )
            self._terminate()
        finally:
            await self._do_writer_close()

    async def checkin_or_register(self, fcmCredentialsFileName) -> str:
        """Check in if you have credentials otherwise register as a new client.

        Delegates to :class:`FcmRegister` and stores the resulting
        credentials on this client. The registrar is closed afterwards,
        even if check-in fails.

        :param fcmCredentialsFileName: credentials file name forwarded to
            ``FcmRegister.checkin_or_register``. Its exact role is defined
            there; presumably it is the credentials storage location.
        :return: The FCM token which is used to identify you with the push end
            point application.
        """
        self.register = FcmRegister(
            self.fcm_config,
            self.credentials,
            self.credentials_updated_callback,
            http_client_session=self._http_client_session,
        )
        try:
            self.credentials = await self.register.checkin_or_register(
                fcmCredentialsFileName
            )
        finally:
            await self.register.close()
        return self.credentials["fcm"]["registration"]["token"]

    async def start(self) -> None:
        """Connect to FCM and start listening for push notifications.

        Spawns the listen and monitor tasks and returns immediately.
        """
        self.reset_lock = asyncio.Lock()
        self.stopping_lock = asyncio.Lock()
        self.do_listen = True
        self.run_state = FcmPushClientRunState.STARTING_TASKS
        try:
            self.tasks = [
                asyncio.create_task(self._listen()),
                asyncio.create_task(self._do_monitor()),
            ]
        except Exception as ex:
            _logger.error("Unexpected error running FcmPushClient: %s", ex)

    async def stop(self) -> None:
        """Stop listening and cancel the background tasks.

        Returns immediately if a stop is already in progress or finished. The
        cancelled tasks are not awaited.
        """
        if (
            self.stopping_lock
            and self.stopping_lock.locked()
            or self.run_state
            in (
                FcmPushClientRunState.STOPPING,
                FcmPushClientRunState.STOPPED,
            )
        ):
            return

        if self.stopping_lock is None:
            # start() was never called: there are no tasks to cancel, and
            # ``async with None`` would raise TypeError.
            self.do_listen = False
            self.run_state = FcmPushClientRunState.STOPPED
            return

        async with self.stopping_lock:
            try:
                self.run_state = FcmPushClientRunState.STOPPING

                self.do_listen = False

                for task in self.tasks:
                    if not task.done():
                        task.cancel()

            finally:
                self.run_state = FcmPushClientRunState.STOPPED
                # Legacy attributes; nothing in this class reads them.
                self.fcm_thread = None
                self.listen_event_loop = None

    def is_started(self) -> bool:
        """True once a successful login response has been received."""
        return self.run_state == FcmPushClientRunState.STARTED

    async def send_message(self, raw_data: bytes, persistent_id: str) -> None:
        """Not implemented, does nothing atm."""
        dms = DataMessageStanza()
        dms.persistent_id = persistent_id

        # Not supported yet