pycupra 0.1.11__py3-2ndver-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,812 @@
1
+ from __future__ import annotations
2
+
3
+ """Functions to receive firebase cloud messaging notifications """
4
+ """Taken from https://github.com/sdb9696/firebase-messaging with small modifications"""
5
+
6
+ import asyncio
7
+ import contextlib
8
+ import json
9
+ import logging
10
+ import ssl
11
+ import struct
12
+ import time
13
+ import traceback
14
+ from base64 import urlsafe_b64decode
15
+ from contextlib import suppress as contextlib_suppress
16
+ from dataclasses import dataclass
17
+ from enum import Enum
18
+ from typing import TYPE_CHECKING, Any, Callable
19
+
20
+ from aiohttp import ClientSession
21
+ from cryptography.hazmat.backends import default_backend
22
+ from cryptography.hazmat.primitives.serialization import load_der_private_key
23
+ from google.protobuf.json_format import MessageToJson
24
+ from google.protobuf.message import Message
25
+ from http_ece import decrypt as http_decrypt # type: ignore[import-untyped]
26
+
27
+ from .const import (
28
+ MCS_HOST,
29
+ MCS_PORT,
30
+ MCS_SELECTIVE_ACK_ID,
31
+ MCS_VERSION,
32
+ )
33
+ from .fcmregister import FcmRegister, FcmRegisterConfig
34
+ from .mcs_pb2 import ( # pylint: disable=no-name-in-module
35
+ Close,
36
+ DataMessageStanza,
37
+ HeartbeatAck,
38
+ HeartbeatPing,
39
+ IqStanza,
40
+ LoginRequest,
41
+ LoginResponse,
42
+ SelectiveAck,
43
+ StreamErrorStanza,
44
+ )
45
+
46
_logger = logging.getLogger(__name__)

# Notification callback signature: (payload, persistent_id, context).
OnNotificationCallable = Callable[[dict[str, Any], str, Any], None]
# Callback signature for receiving newly created credentials.
CredentialsUpdatedCallable = Callable[[dict[str, Any]], None]

# MCS Message Types and Tags.
# Each entry's wire tag is its position in the tuple below
# (HeartbeatPing -> 0 ... "TalkMetadata" -> 15). Entries spelled as plain
# strings have no protobuf class imported in this module, so frames with
# those tags are not decoded.
MCS_MESSAGE_TAG = {
    message_type: tag
    for tag, message_type in enumerate(
        (
            HeartbeatPing,  # 0
            HeartbeatAck,  # 1
            LoginRequest,  # 2
            LoginResponse,  # 3
            Close,  # 4
            "MessageStanza",  # 5
            "PresenceStanza",  # 6
            IqStanza,  # 7
            DataMessageStanza,  # 8
            "BatchPresenceStanza",  # 9
            StreamErrorStanza,  # 10
            "HttpRequest",  # 11
            "HttpResponse",  # 12
            "BindAccountRequest",  # 13
            "BindAccountResponse",  # 14
            "TalkMetadata",  # 15
        )
    )
}
70
+
71
+
72
class ErrorType(Enum):
    """Error categories, each with its own sequential error counter.

    The push client counts consecutive errors per category and can abort
    once a category reaches the configured limit.
    """

    # Server sent Close, or an unexpected socket error occurred while reading.
    CONNECTION = 1
    # Read errors; the counter is cleared after every successfully handled message.
    READ = 2
    # Sending the login request failed, or the server answered with an error.
    LOGIN = 3
    # The user-supplied notification callback raised.
    NOTIFY = 4
77
+
78
+
79
class FcmPushClientRunState(Enum):
    """Lifecycle states of the FCM push client.

    NOTE(review): every value is a one-element tuple, apparently from stray
    trailing commas. The values are kept unchanged for compatibility; the
    client itself only compares members, never their values.
    """

    CREATED = (1,)  # constructed, start() not yet called
    STARTING_TASKS = (2,)  # start() is spawning the listen/monitor tasks
    STARTING_CONNECTION = (3,)  # opening the TLS connection (with retries)
    STARTING_LOGIN = (4,)  # login request being sent
    STARTED = (5,)  # login acknowledged; receiving messages
    RESETTING = (6,)  # tearing down and re-establishing the connection
    STOPPING = (7,)  # shutdown requested; tasks being cancelled
    STOPPED = (8,)  # stop() finished
88
+
89
+
90
@dataclass
class FcmPushClientConfig:  # pylint:disable=too-many-instance-attributes
    """Configuration for :class:`FcmPushClient`.

    All times are in seconds.
    """

    # Heartbeat interval requested from the server (original upstream value was 10).
    server_heartbeat_interval: int | None = 30

    # Interval at which the client sends heartbeats (original upstream value was 20).
    client_heartbeat_interval: int | None = 40

    # Send a selective acknowledgement for each received message. When False
    # the client currently sends no acknowledgements at all.
    send_selective_acknowledgements: bool = True

    # Number of connection attempts before giving up.
    connection_retry_count: int = 5

    # Base delay before retrying a failed connection attempt.
    start_seconds_before_retry_connect: float = 3

    # Minimum wait between resets after errors or disconnection.
    reset_interval: float = 3

    # How long to wait for a heartbeat ack before resetting.
    heartbeat_ack_timeout: float = 5

    # Number of sequential errors of the same type tolerated before aborting.
    # None means the client never aborts.
    abort_on_sequential_error_count: int | None = 3

    # Period of the monitor task that checks heartbeats, stale connections and
    # shutdown of the main event loop.
    monitor_interval: float = 1

    # How many times each distinct warning message is logged before that
    # warning goes silent.
    log_warn_limit: int | None = 5

    # Log all message details, including tokens.
    log_debug_verbose: bool = False
132
+
133
+
134
class FcmPushClient:  # pylint:disable=too-many-instance-attributes
    """Client that connects to Firebase Cloud Messaging and receives messages.

    :param callback: coroutine function awaited for every decrypted data
        message as ``callback(data, persistent_id, callback_context)``.
        ``data`` is the parsed JSON payload when the decrypted bytes decode
        to truthy JSON; otherwise it is the raw decrypted bytes.
    :param fcm_config: registration configuration, used for the login
        request and passed to :class:`FcmRegister`.
    :param credentials: credentials object returned by register()
    :param credentials_updated_callback: callback when new credentials are
        created to allow client to store them
    :param callback_context: object passed unchanged as the third argument
        of ``callback``.
    :param received_persistent_ids: any persistent id's you already received.
    :param config: configuration class of :class:`FcmPushClientConfig`
    :param http_client_session: optional aiohttp session passed to
        :class:`FcmRegister`.
    """

    def __init__(
        self,
        callback: Callable[[dict, str, Any | None], None],
        fcm_config: FcmRegisterConfig,
        credentials: dict | None = None,
        credentials_updated_callback: CredentialsUpdatedCallable | None = None,
        *,
        callback_context: object | None = None,
        received_persistent_ids: list[str] | None = None,
        config: FcmPushClientConfig | None = None,
        http_client_session: ClientSession | None = None,
    ):
        """Initializes the receiver."""
        self.callback = callback
        self.callback_context = callback_context
        self.fcm_config = fcm_config
        self.credentials = credentials
        self.credentials_updated_callback = credentials_updated_callback
        self.persistent_ids = received_persistent_ids if received_persistent_ids else []
        self.config = config if config else FcmPushClientConfig()
        if self.config.log_debug_verbose:
            # Note: this changes the level of the module-wide logger.
            _logger.setLevel(logging.DEBUG)
        self._http_client_session = http_client_session

        self.reader: asyncio.StreamReader | None = None
        self.writer: asyncio.StreamWriter | None = None
        self.do_listen = False
        self.sequential_error_counters: dict[ErrorType, int] = {}
        self.log_warn_counters: dict[str, int] = {}

        # reset variables (re-initialised by every _login())
        self.input_stream_id = 0
        self.last_input_stream_id_reported = -1
        self.first_message = True
        self.last_login_time: float | None = None
        self.last_message_time: float | None = None

        self.run_state: FcmPushClientRunState = FcmPushClientRunState.CREATED
        self.tasks: list[asyncio.Task] = []

        # Created by start(); None until then.
        self.reset_lock: asyncio.Lock | None = None
        self.stopping_lock: asyncio.Lock | None = None

    def _msg_str(self, msg: Message) -> str:
        """Describe a protobuf message for logging (full JSON only in verbose mode)."""
        if self.config.log_debug_verbose:
            return type(msg).__name__ + "\n" + MessageToJson(msg, indent=4)
        return type(msg).__name__

    def _log_verbose(self, msg: str, *args: object) -> None:
        """Debug-log only when ``log_debug_verbose`` is enabled."""
        if self.config.log_debug_verbose:
            _logger.debug(msg, *args)

    def _log_warn_with_limit(self, msg: str, *args: object) -> None:
        """Log a warning at most ``log_warn_limit`` times per distinct format string.

        When ``log_warn_limit`` is falsy (None or 0) nothing is logged.
        """
        if msg not in self.log_warn_counters:
            self.log_warn_counters[msg] = 0
        if (
            self.config.log_warn_limit
            and self.config.log_warn_limit > self.log_warn_counters[msg]
        ):
            self.log_warn_counters[msg] += 1
            _logger.warning(msg, *args)

    async def _do_writer_close(self) -> None:
        """Close and forget the current stream writer, ignoring close errors."""
        writer = self.writer
        self.writer = None
        if writer:
            writer.close()
            with contextlib.suppress(Exception):
                await writer.wait_closed()

    async def _reset(self) -> None:
        """Drop the current connection, reconnect and log in again.

        Returns immediately when a reset or stop is already in progress or the
        client is not listening. Terminates the client if the connection cannot
        be re-established.
        """
        if (
            (self.reset_lock and self.reset_lock.locked())
            or (self.stopping_lock and self.stopping_lock.locked())
            or not self.do_listen
        ):
            # The locks are None until start() runs, so only query them when set.
            _logger.debug(
                "In _reset. reset_lock=%s, reset_lock.locked=%s, stopping_lock=%s, "
                "stopping_lock.locked=%s, do_listen=%s",
                self.reset_lock,
                self.reset_lock.locked() if self.reset_lock else None,
                self.stopping_lock,
                self.stopping_lock.locked() if self.stopping_lock else None,
                self.do_listen,
            )
            return

        async with self.reset_lock:  # type: ignore[union-attr]
            _logger.debug("Resetting connection")

            self.run_state = FcmPushClientRunState.RESETTING

            await self._do_writer_close()

            # Throttle: keep at least reset_interval between login attempts.
            if self.last_login_time is not None:
                time_since_last_login = time.time() - self.last_login_time
                if time_since_last_login < self.config.reset_interval:
                    _logger.debug("%ss since last reset attempt.", time_since_last_login)
                    await asyncio.sleep(
                        self.config.reset_interval - time_since_last_login
                    )

            _logger.debug("Reestablishing connection")
            if not await self._connect_with_retry():
                _logger.error(
                    "Unable to connect to MCS endpoint "
                    + "after %s tries, shutting down",
                    self.config.connection_retry_count,
                )
                self._terminate()
                return
            _logger.debug("Re-connected to ssl socket")

            await self._login()

    # protobuf variable length integers are encoded in base 128
    # each byte contains 7 bits of the integer and the msb is set if there's
    # more. pretty simple to implement
    async def _read_varint32(self) -> int:
        """Read one base-128 varint from the stream reader."""
        res = 0
        shift = 0
        while True:
            r = await self.reader.readexactly(1)  # type: ignore[union-attr]
            (b,) = struct.unpack("B", r)
            res |= (b & 0x7F) << shift
            if (b & 0x80) == 0:
                break
            shift += 7
        return res

    @staticmethod
    def _encode_varint32(x: int) -> bytes:
        """Encode a non-negative integer as a base-128 varint."""
        if x == 0:
            return bytes(bytearray([0]))

        res = bytearray([])
        while x != 0:
            b = x & 0x7F
            x >>= 7
            if x != 0:
                b |= 0x80
            res.append(b)
        return bytes(res)

    @staticmethod
    def _make_packet(msg: Message, include_version: bool) -> bytes:
        """Frame a message as [version] tag varint(len) payload."""
        tag = MCS_MESSAGE_TAG[type(msg)]

        header = bytearray([MCS_VERSION, tag]) if include_version else bytearray([tag])

        payload = msg.SerializeToString()
        return bytes(header) + FcmPushClient._encode_varint32(len(payload)) + payload

    async def _send_msg(self, msg: Message) -> None:
        """Frame and write a message, then wait for the writer to drain.

        The version byte is prepended while ``first_message`` is still True.
        """
        self._log_verbose("Sending packet to server: %s", self._msg_str(msg))

        buf = FcmPushClient._make_packet(msg, self.first_message)
        self.writer.write(buf)  # type: ignore[union-attr]
        await self.writer.drain()  # type: ignore[union-attr]

    async def _receive_msg(self) -> Message | None:
        """Read one MCS frame and decode it into its protobuf message.

        The first frame after login starts with a protocol version byte.
        Returns None if the frame's tag is unknown or has no protobuf class.
        In that case the payload has already been consumed, so the stream
        stays aligned.

        :raises RuntimeError: if the server announces an unsupported version.
        """
        if self.first_message:
            r = await self.reader.readexactly(2)  # type: ignore[union-attr]
            version, tag = struct.unpack("BB", r)
            if version < MCS_VERSION and version != 38:
                raise RuntimeError(f"protocol version {version} unsupported")
            self.first_message = False
        else:
            r = await self.reader.readexactly(1)  # type: ignore[union-attr]
            (tag,) = struct.unpack("B", r)
        size = await self._read_varint32()

        self._log_verbose(
            "Received message with tag %s and size %s",
            tag,
            size,
        )

        if not size >= 0:
            self._log_warn_with_limit("Unexpected message size %s", size)
            return None

        buf = await self.reader.readexactly(size)  # type: ignore[union-attr]

        # next() needs a default: a StopIteration escaping a coroutine is
        # turned into RuntimeError, which _listen treats as fatal.
        msg_class = next(
            (cls for cls, cls_tag in MCS_MESSAGE_TAG.items() if cls_tag == tag),
            None,
        )
        if msg_class is None:
            self._log_warn_with_limit("Unexpected message tag %s", tag)
            return None
        if isinstance(msg_class, str):
            self._log_warn_with_limit("Unconfigured message class %s", msg_class)
            return None

        payload = msg_class()  # type: ignore[operator]
        payload.ParseFromString(buf)
        self._log_verbose("Received payload: %s", self._msg_str(payload))

        return payload

    async def _login(self) -> None:
        """Reset per-connection state and send a LoginRequest.

        On failure, the LOGIN error counter is incremented and a reset is
        attempted (or the client terminates once the limit is reached).
        """
        self.run_state = FcmPushClientRunState.STARTING_LOGIN

        now = time.time()
        self.input_stream_id = 0
        self.last_input_stream_id_reported = -1
        self.first_message = True
        self.last_login_time = now

        try:
            android_id = self.credentials["gcm"]["android_id"]  # type: ignore[index]
            req = LoginRequest()
            req.adaptive_heartbeat = False
            req.auth_service = LoginRequest.ANDROID_ID  # 2
            req.auth_token = self.credentials["gcm"]["security_token"]  # type: ignore[index]
            req.id = self.fcm_config.chrome_version
            req.domain = "mcs.android.com"
            req.device_id = f"android-{int(android_id):x}"
            req.network_type = 1
            req.resource = android_id
            req.user = android_id
            req.use_rmq2 = True
            req.setting.add(name="new_vc", value="1")
            req.received_persistent_id.extend(self.persistent_ids)
            if (
                self.config.server_heartbeat_interval
                and self.config.server_heartbeat_interval > 0
            ):
                req.heartbeat_stat.ip = ""
                req.heartbeat_stat.timeout = True
                req.heartbeat_stat.interval_ms = (
                    1000 * self.config.server_heartbeat_interval
                )

            await self._send_msg(req)
            _logger.debug("Sent login request")
        except Exception as ex:
            _logger.error("Received an exception logging in: %s", ex)
            if self._try_increment_error_count(ErrorType.LOGIN):
                await self._reset()

    @staticmethod
    def _decrypt_raw_data(
        credentials: dict[str, dict[str, str]],
        crypto_key_str: str,
        salt_str: str,
        raw_data: bytes,
    ) -> bytes:
        """Decrypt an "aesgcm" payload using the stored private key and auth secret."""
        crypto_key = urlsafe_b64decode(crypto_key_str.encode("ascii"))
        salt = urlsafe_b64decode(salt_str.encode("ascii"))
        der_data_str = credentials["keys"]["private"]
        # Extra padding lets unpadded base64url strings decode.
        der_data = urlsafe_b64decode(der_data_str.encode("ascii") + b"========")
        secret_str = credentials["keys"]["secret"]
        secret = urlsafe_b64decode(secret_str.encode("ascii") + b"========")
        privkey = load_der_private_key(
            der_data, password=None, backend=default_backend()
        )
        return http_decrypt(
            raw_data,
            salt=salt,
            private_key=privkey,
            dh=crypto_key,
            version="aesgcm",
            auth_secret=secret,
        )

    def _app_data_by_key(
        self, p: DataMessageStanza, key: str, do_not_raise: bool = False
    ) -> str:
        """Return the app_data value for ``key``.

        :raises RuntimeError: if the key is missing, unless ``do_not_raise``
            is set (then "" is returned).
        """
        for x in p.app_data:
            if x.key == key:
                return x.value

        if do_not_raise:
            return ""
        raise RuntimeError(f"couldn't find in app_data {key}")

    async def _handle_data_message(
        self,
        msg: DataMessageStanza,
    ) -> None:
        """Decrypt a data message and hand it to the notification callback.

        ``deleted_messages`` notifications carry no data and are ignored.
        Exceptions raised by the callback are logged and counted as NOTIFY
        errors.
        """
        _logger.debug(
            "Received data message Stream ID: %s, Last: %s, Status: %s",
            msg.stream_id,
            msg.last_stream_id_received,
            msg.status,
        )

        if (
            self._app_data_by_key(msg, "message_type", do_not_raise=True)
            == "deleted_messages"
        ):
            # The deleted_messages message does not contain data.
            return
        crypto_key = self._app_data_by_key(msg, "crypto-key")[3:]  # strip dh=
        salt = self._app_data_by_key(msg, "encryption")[5:]  # strip salt=
        subtype = self._app_data_by_key(msg, "subtype")
        # Check credentials before indexing them (they may be None).
        if not self.credentials:
            return
        if subtype != self.credentials["gcm"]["app_id"]:
            self._log_warn_with_limit(
                "Subtype %s in data message does not match "
                + "app id client was registered with %s",
                subtype,
                self.credentials["gcm"]["app_id"],
            )
        decrypted = self._decrypt_raw_data(
            self.credentials, crypto_key, salt, msg.raw_data
        )
        # Non-JSON (or non-UTF-8) payloads fall back to the raw bytes.
        decrypted_json = None
        with contextlib_suppress(json.JSONDecodeError, ValueError):
            decrypted_json = json.loads(decrypted.decode("utf-8"))

        ret_val = decrypted_json if decrypted_json else decrypted
        self._log_verbose(
            "Decrypted data for message %s is: %s", msg.persistent_id, ret_val
        )
        try:
            await self.callback(ret_val, msg.persistent_id, self.callback_context)
            self._reset_error_count(ErrorType.NOTIFY)
        except Exception:
            _logger.exception("Unexpected exception calling notification callback\n")
            self._try_increment_error_count(ErrorType.NOTIFY)

    def _new_input_stream_id_available(self) -> bool:
        """True if input_stream_id changed since it was last reported to the server."""
        return self.last_input_stream_id_reported != self.input_stream_id

    def _get_input_stream_id(self) -> int:
        """Return input_stream_id and record it as reported."""
        self.last_input_stream_id_reported = self.input_stream_id
        return self.input_stream_id

    async def _handle_ping(self, p: HeartbeatPing) -> None:
        """Answer a server heartbeat ping with a HeartbeatAck."""
        req = HeartbeatAck()

        if self._new_input_stream_id_available():
            req.last_stream_id_received = self._get_input_stream_id()

        await self._send_msg(req)

    async def _handle_iq(self, p: IqStanza) -> None:
        """Warn about IqStanza messages with a missing or unexpected extension."""
        # NOTE(review): protobuf sub-messages are normally always truthy;
        # p.HasField("extension") may be what was intended - confirm.
        if not p.extension:
            self._log_warn_with_limit(
                "Unexpected IqStanza id received with no extension: %s", str(p)
            )
            return
        if p.extension.id not in (12, 13):
            self._log_warn_with_limit(
                "Unexpected extension id received: %s", p.extension.id
            )
            return

    async def _send_selective_ack(self, persistent_id: str) -> None:
        """Acknowledge a single received message by its persistent id."""
        iqs = IqStanza()
        iqs.type = IqStanza.IqType.SET
        iqs.id = ""
        iqs.extension.id = MCS_SELECTIVE_ACK_ID
        sa = SelectiveAck()
        sa.id.extend([persistent_id])
        iqs.extension.data = sa.SerializeToString()
        _logger.debug("Sending selective ack for message id %s", persistent_id)
        await self._send_msg(iqs)

    async def _send_heartbeat(self) -> None:
        """Send a client heartbeat ping, reporting the input stream id if new."""
        req = HeartbeatPing()

        if self._new_input_stream_id_available():
            req.last_stream_id_received = self._get_input_stream_id()

        await self._send_msg(req)

    def _terminate(self) -> None:
        """Stop listening and cancel all client tasks except the current one."""
        self.run_state = FcmPushClientRunState.STOPPING

        self.do_listen = False
        current_task = asyncio.current_task()
        for task in self.tasks:
            # cancel() on a finished task is a no-op, but skip those anyway.
            if current_task != task and not task.done():
                task.cancel()

    async def _do_monitor(self) -> None:
        """Periodically send heartbeats and reset stale connections.

        With ``client_heartbeat_interval`` set, a ping is sent once the line
        has been idle that long; if nothing arrives within
        ``heartbeat_ack_timeout`` the connection is reset. Otherwise, with
        ``server_heartbeat_interval`` set, the connection is reset when the
        server stays silent for that interval plus 2 seconds.
        """
        while self.do_listen:
            await asyncio.sleep(self.config.monitor_interval)

            if self.run_state != FcmPushClientRunState.STARTED:
                continue

            # if server_heartbeat_interval is set and less than
            # client_heartbeat_interval then the last_message_time
            # will be within the client window if connected
            if self.config.client_heartbeat_interval:
                if (
                    self.last_message_time + self.config.client_heartbeat_interval  # type: ignore[operator]
                    < time.time()
                ):
                    try:
                        await self._send_heartbeat()
                    except OSError as ex:
                        # Keep the monitor alive; the ack check below resets.
                        self._log_warn_with_limit(
                            "Failed to send heartbeat ping: %s", ex
                        )
                    await asyncio.sleep(self.config.heartbeat_ack_timeout)
                    if (  # Check state hasn't changed during sleep
                        self.last_message_time  # type: ignore[operator]
                        + self.config.client_heartbeat_interval
                        < time.time()
                        and self.do_listen
                        and self.run_state == FcmPushClientRunState.STARTED
                    ):
                        await self._reset()
            elif self.config.server_heartbeat_interval:
                if (  # We give the server 2 extra seconds
                    self.last_message_time + self.config.server_heartbeat_interval  # type: ignore[operator]
                    < time.time() - 2
                ):
                    await self._reset()

    def _reset_error_count(self, error_type: ErrorType) -> None:
        """Clear the sequential error counter for ``error_type``."""
        self.sequential_error_counters[error_type] = 0

    def _try_increment_error_count(self, error_type: ErrorType) -> bool:
        """Count one more error of ``error_type``.

        :return: False (after terminating the client) if the configured
            sequential error limit has been reached, True otherwise.
        """
        count = self.sequential_error_counters.get(error_type, 0) + 1
        self.sequential_error_counters[error_type] = count

        limit = self.config.abort_on_sequential_error_count
        if limit and count >= limit:
            _logger.error(
                "Shutting down push receiver due to "
                + f"{count} sequential"
                + f" errors of type {error_type}"
            )
            self._terminate()
            return False
        return True

    async def _handle_message(self, msg: Message) -> None:
        """Dispatch one decoded server message by type."""
        self.last_message_time = time.time()
        self.input_stream_id += 1

        if isinstance(msg, Close):
            self._log_warn_with_limit("Server sent Close message, resetting")
            if self._try_increment_error_count(ErrorType.CONNECTION):
                await self._reset()
            return

        if isinstance(msg, LoginResponse):
            if str(msg.error):
                _logger.error("Received login error response: %s", msg)
                if self._try_increment_error_count(ErrorType.LOGIN):
                    await self._reset()
            else:
                _logger.info("Successfully logged in to MCS endpoint")
                self._reset_error_count(ErrorType.LOGIN)
                self.run_state = FcmPushClientRunState.STARTED
                self.persistent_ids = []
            return

        if isinstance(msg, DataMessageStanza):
            if self.config.send_selective_acknowledgements:
                # The notification callback may be slow, so the selective ack
                # is sent concurrently with handling the data message.
                results = await asyncio.gather(
                    self._handle_data_message(msg),
                    self._send_selective_ack(msg.persistent_id),
                    return_exceptions=True,
                )
                # return_exceptions=True hands errors back as results; log them
                # instead of dropping them silently.
                for result in results:
                    if isinstance(result, BaseException):
                        _logger.error(
                            "Error while processing data message %s: %r",
                            msg.persistent_id,
                            result,
                        )
            else:
                await self._handle_data_message(msg)
            self.persistent_ids.append(msg.persistent_id)
        elif isinstance(msg, HeartbeatPing):
            await self._handle_ping(msg)
        elif isinstance(msg, (HeartbeatAck, IqStanza)):
            pass
        else:
            self._log_warn_with_limit("Unexpected message type %s.", type(msg).__name__)
        # Reset error count if a read has been successful
        self._reset_error_count(ErrorType.READ)
        self._reset_error_count(ErrorType.CONNECTION)

    @staticmethod
    async def _open_connection(
        host: str, port: int, ssl_context: ssl.SSLContext
    ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """Open a TLS stream connection."""
        return await asyncio.open_connection(host=host, port=port, ssl=ssl_context)

    async def _connect(self) -> bool:
        """Open the TLS connection to the MCS endpoint; return True on success."""
        try:
            loop = asyncio.get_running_loop()
            # create_default_context() blocks the event loop
            ssl_context = await loop.run_in_executor(None, ssl.create_default_context)
            self.reader, self.writer = await self._open_connection(
                host=MCS_HOST, port=MCS_PORT, ssl_context=ssl_context
            )
            _logger.debug("Connected to MCS endpoint (%s,%s)", MCS_HOST, MCS_PORT)
            return True
        except OSError as oex:
            _logger.error(
                "Could not connect to MCS endpoint (%s,%s): %s",
                MCS_HOST,
                MCS_PORT,
                oex,
            )
            return False

    async def _connect_with_retry(self) -> bool:
        """Try to connect up to ``connection_retry_count`` times.

        The delay before retry n is ``start_seconds_before_retry_connect * n**2``.
        No delay follows the final failed attempt.
        """
        self.run_state = FcmPushClientRunState.STARTING_CONNECTION

        max_tries = self.config.connection_retry_count
        trycount = 0
        connected = False
        while trycount < max_tries and not connected and self.do_listen:
            trycount += 1
            connected = await self._connect()
            if not connected and trycount < max_tries:
                sleep_time = (
                    self.config.start_seconds_before_retry_connect * trycount * trycount
                )
                _logger.info(
                    "Could not connect to MCS Endpoint on "
                    + "try %s, sleeping for %s seconds",
                    trycount,
                    sleep_time,
                )
                await asyncio.sleep(sleep_time)
        if not connected:
            _logger.error(
                "Unable to connect to MCS endpoint after %s tries, aborting", trycount
            )
        return connected

    async def _listen(self) -> None:
        """listens for push notifications."""
        if not await self._connect_with_retry():
            return

        try:
            await self._login()

            while self.do_listen:
                try:
                    if self.run_state == FcmPushClientRunState.RESETTING:
                        await asyncio.sleep(1)
                    elif (msg := await self._receive_msg()) is not None:
                        await self._handle_message(msg)

                except (OSError, EOFError) as osex:
                    if (
                        isinstance(
                            osex,
                            (
                                ConnectionResetError,
                                TimeoutError,
                                asyncio.IncompleteReadError,
                                ssl.SSLError,
                            ),
                        )
                        and self.run_state == FcmPushClientRunState.RESETTING
                    ):
                        # The socket was closed under us by _reset(); expected.
                        if (
                            isinstance(osex, ssl.SSLError)  # pylint: disable=no-member
                            and osex.reason != "APPLICATION_DATA_AFTER_CLOSE_NOTIFY"
                        ):
                            self._log_warn_with_limit(
                                "Unexpected SSLError reason during reset of %s",
                                osex.reason,
                            )
                        else:
                            self._log_verbose(
                                "Expected read error during reset: %s",
                                type(osex).__name__,
                            )
                    else:
                        _logger.exception("Unexpected exception during read\n")
                        if self._try_increment_error_count(ErrorType.CONNECTION):
                            _logger.debug("Calling reset()\n")
                            await self._reset()
                        else:
                            _logger.debug("Not calling reset()\n")
        except Exception as ex:
            _logger.error(
                "Unknown error: %s, shutting down FcmPushClient.\n%s",
                ex,
                traceback.format_exc(),
            )
            self._terminate()
        finally:
            await self._do_writer_close()

    async def checkin_or_register(self, fcmCredentialsFileName) -> str:
        """Check in if you have credentials otherwise register as a new client.

        :param fcmCredentialsFileName: credentials file name passed through to
            ``FcmRegister.checkin_or_register``.
        :return: The FCM token which is used to identify you with the push end
            point application.
        """
        self.register = FcmRegister(
            self.fcm_config,
            self.credentials,
            self.credentials_updated_callback,
            http_client_session=self._http_client_session,
        )
        try:
            self.credentials = await self.register.checkin_or_register(
                fcmCredentialsFileName
            )
            # await self.register.fcm_refresh_install()
        finally:
            # Close the register even when check-in/registration fails.
            await self.register.close()
        return self.credentials["fcm"]["registration"]["token"]

    async def start(self) -> None:
        """Connect to FCM and start listening for push notifications."""
        self.reset_lock = asyncio.Lock()
        self.stopping_lock = asyncio.Lock()
        self.do_listen = True
        self.run_state = FcmPushClientRunState.STARTING_TASKS
        try:
            self.tasks = [
                asyncio.create_task(self._listen()),
                asyncio.create_task(self._do_monitor()),
            ]
        except Exception as ex:
            _logger.error("Unexpected error running FcmPushClient: %s", ex)

    async def stop(self) -> None:
        """Stop listening and cancel the client's background tasks.

        Does nothing if start() was never called or a stop is already in
        progress or finished.
        """
        if self.stopping_lock is None:
            # start() was never called, so there is nothing to stop.
            return
        if self.stopping_lock.locked() or self.run_state in (
            FcmPushClientRunState.STOPPING,
            FcmPushClientRunState.STOPPED,
        ):
            return

        async with self.stopping_lock:
            try:
                self.run_state = FcmPushClientRunState.STOPPING

                self.do_listen = False

                for task in self.tasks:
                    if not task.done():
                        task.cancel()

            finally:
                self.run_state = FcmPushClientRunState.STOPPED
                self.fcm_thread = None
                self.listen_event_loop = None

    def is_started(self) -> bool:
        """True once the login has been acknowledged and messages are flowing."""
        return self.run_state == FcmPushClientRunState.STARTED

    async def send_message(self, raw_data: bytes, persistent_id: str) -> None:
        """Not implemented, does nothing atm."""
        dms = DataMessageStanza()
        dms.persistent_id = persistent_id

        # Not supported yet
812
+