roborock-cli 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. roborock_cli/__init__.py +3 -0
  2. roborock_cli/__main__.py +76 -0
  3. roborock_cli/_vendor/VERSION +6 -0
  4. roborock_cli/_vendor/__init__.py +0 -0
  5. roborock_cli/_vendor/roborock/__init__.py +27 -0
  6. roborock_cli/_vendor/roborock/broadcast_protocol.py +114 -0
  7. roborock_cli/_vendor/roborock/callbacks.py +130 -0
  8. roborock_cli/_vendor/roborock/cli.py +1338 -0
  9. roborock_cli/_vendor/roborock/const.py +84 -0
  10. roborock_cli/_vendor/roborock/data/__init__.py +9 -0
  11. roborock_cli/_vendor/roborock/data/b01_q10/__init__.py +2 -0
  12. roborock_cli/_vendor/roborock/data/b01_q10/b01_q10_code_mappings.py +213 -0
  13. roborock_cli/_vendor/roborock/data/b01_q10/b01_q10_containers.py +102 -0
  14. roborock_cli/_vendor/roborock/data/b01_q7/__init__.py +2 -0
  15. roborock_cli/_vendor/roborock/data/b01_q7/b01_q7_code_mappings.py +303 -0
  16. roborock_cli/_vendor/roborock/data/b01_q7/b01_q7_containers.py +302 -0
  17. roborock_cli/_vendor/roborock/data/code_mappings.py +198 -0
  18. roborock_cli/_vendor/roborock/data/containers.py +530 -0
  19. roborock_cli/_vendor/roborock/data/dyad/__init__.py +2 -0
  20. roborock_cli/_vendor/roborock/data/dyad/dyad_code_mappings.py +102 -0
  21. roborock_cli/_vendor/roborock/data/dyad/dyad_containers.py +28 -0
  22. roborock_cli/_vendor/roborock/data/v1/__init__.py +3 -0
  23. roborock_cli/_vendor/roborock/data/v1/v1_clean_modes.py +192 -0
  24. roborock_cli/_vendor/roborock/data/v1/v1_code_mappings.py +644 -0
  25. roborock_cli/_vendor/roborock/data/v1/v1_containers.py +800 -0
  26. roborock_cli/_vendor/roborock/data/zeo/__init__.py +2 -0
  27. roborock_cli/_vendor/roborock/data/zeo/zeo_code_mappings.py +138 -0
  28. roborock_cli/_vendor/roborock/data/zeo/zeo_containers.py +0 -0
  29. roborock_cli/_vendor/roborock/device_features.py +668 -0
  30. roborock_cli/_vendor/roborock/devices/README.md +41 -0
  31. roborock_cli/_vendor/roborock/devices/__init__.py +11 -0
  32. roborock_cli/_vendor/roborock/devices/cache.py +143 -0
  33. roborock_cli/_vendor/roborock/devices/device.py +240 -0
  34. roborock_cli/_vendor/roborock/devices/device_manager.py +269 -0
  35. roborock_cli/_vendor/roborock/devices/file_cache.py +79 -0
  36. roborock_cli/_vendor/roborock/devices/rpc/__init__.py +14 -0
  37. roborock_cli/_vendor/roborock/devices/rpc/a01_channel.py +94 -0
  38. roborock_cli/_vendor/roborock/devices/rpc/b01_q10_channel.py +57 -0
  39. roborock_cli/_vendor/roborock/devices/rpc/b01_q7_channel.py +101 -0
  40. roborock_cli/_vendor/roborock/devices/rpc/v1_channel.py +457 -0
  41. roborock_cli/_vendor/roborock/devices/traits/__init__.py +28 -0
  42. roborock_cli/_vendor/roborock/devices/traits/a01/__init__.py +191 -0
  43. roborock_cli/_vendor/roborock/devices/traits/b01/__init__.py +12 -0
  44. roborock_cli/_vendor/roborock/devices/traits/b01/q10/__init__.py +76 -0
  45. roborock_cli/_vendor/roborock/devices/traits/b01/q10/command.py +32 -0
  46. roborock_cli/_vendor/roborock/devices/traits/b01/q10/common.py +115 -0
  47. roborock_cli/_vendor/roborock/devices/traits/b01/q10/status.py +32 -0
  48. roborock_cli/_vendor/roborock/devices/traits/b01/q10/vacuum.py +81 -0
  49. roborock_cli/_vendor/roborock/devices/traits/b01/q7/__init__.py +136 -0
  50. roborock_cli/_vendor/roborock/devices/traits/b01/q7/clean_summary.py +75 -0
  51. roborock_cli/_vendor/roborock/devices/traits/traits_mixin.py +64 -0
  52. roborock_cli/_vendor/roborock/devices/traits/v1/__init__.py +344 -0
  53. roborock_cli/_vendor/roborock/devices/traits/v1/child_lock.py +29 -0
  54. roborock_cli/_vendor/roborock/devices/traits/v1/clean_summary.py +83 -0
  55. roborock_cli/_vendor/roborock/devices/traits/v1/command.py +38 -0
  56. roborock_cli/_vendor/roborock/devices/traits/v1/common.py +172 -0
  57. roborock_cli/_vendor/roborock/devices/traits/v1/consumeable.py +48 -0
  58. roborock_cli/_vendor/roborock/devices/traits/v1/device_features.py +74 -0
  59. roborock_cli/_vendor/roborock/devices/traits/v1/do_not_disturb.py +41 -0
  60. roborock_cli/_vendor/roborock/devices/traits/v1/dust_collection_mode.py +13 -0
  61. roborock_cli/_vendor/roborock/devices/traits/v1/flow_led_status.py +29 -0
  62. roborock_cli/_vendor/roborock/devices/traits/v1/home.py +285 -0
  63. roborock_cli/_vendor/roborock/devices/traits/v1/led_status.py +43 -0
  64. roborock_cli/_vendor/roborock/devices/traits/v1/map_content.py +83 -0
  65. roborock_cli/_vendor/roborock/devices/traits/v1/maps.py +80 -0
  66. roborock_cli/_vendor/roborock/devices/traits/v1/network_info.py +55 -0
  67. roborock_cli/_vendor/roborock/devices/traits/v1/rooms.py +105 -0
  68. roborock_cli/_vendor/roborock/devices/traits/v1/routines.py +26 -0
  69. roborock_cli/_vendor/roborock/devices/traits/v1/smart_wash_params.py +13 -0
  70. roborock_cli/_vendor/roborock/devices/traits/v1/status.py +101 -0
  71. roborock_cli/_vendor/roborock/devices/traits/v1/valley_electricity_timer.py +44 -0
  72. roborock_cli/_vendor/roborock/devices/traits/v1/volume.py +27 -0
  73. roborock_cli/_vendor/roborock/devices/traits/v1/wash_towel_mode.py +13 -0
  74. roborock_cli/_vendor/roborock/devices/transport/__init__.py +8 -0
  75. roborock_cli/_vendor/roborock/devices/transport/channel.py +32 -0
  76. roborock_cli/_vendor/roborock/devices/transport/local_channel.py +295 -0
  77. roborock_cli/_vendor/roborock/devices/transport/mqtt_channel.py +118 -0
  78. roborock_cli/_vendor/roborock/diagnostics.py +166 -0
  79. roborock_cli/_vendor/roborock/exceptions.py +95 -0
  80. roborock_cli/_vendor/roborock/map/__init__.py +7 -0
  81. roborock_cli/_vendor/roborock/map/map_parser.py +123 -0
  82. roborock_cli/_vendor/roborock/mqtt/__init__.py +10 -0
  83. roborock_cli/_vendor/roborock/mqtt/health_manager.py +60 -0
  84. roborock_cli/_vendor/roborock/mqtt/roborock_session.py +463 -0
  85. roborock_cli/_vendor/roborock/mqtt/session.py +108 -0
  86. roborock_cli/_vendor/roborock/protocol.py +558 -0
  87. roborock_cli/_vendor/roborock/protocols/__init__.py +3 -0
  88. roborock_cli/_vendor/roborock/protocols/a01_protocol.py +74 -0
  89. roborock_cli/_vendor/roborock/protocols/b01_q10_protocol.py +87 -0
  90. roborock_cli/_vendor/roborock/protocols/b01_q7_protocol.py +81 -0
  91. roborock_cli/_vendor/roborock/protocols/v1_protocol.py +271 -0
  92. roborock_cli/_vendor/roborock/py.typed +0 -0
  93. roborock_cli/_vendor/roborock/roborock_message.py +246 -0
  94. roborock_cli/_vendor/roborock/roborock_typing.py +382 -0
  95. roborock_cli/_vendor/roborock/util.py +54 -0
  96. roborock_cli/_vendor/roborock/web_api.py +761 -0
  97. roborock_cli/cli.py +715 -0
  98. roborock_cli/connection.py +202 -0
  99. roborock_cli/helpers.py +71 -0
  100. roborock_cli/server.py +759 -0
  101. roborock_cli/setup_auth.py +92 -0
  102. roborock_cli-0.1.1.dist-info/METADATA +172 -0
  103. roborock_cli-0.1.1.dist-info/RECORD +106 -0
  104. roborock_cli-0.1.1.dist-info/WHEEL +4 -0
  105. roborock_cli-0.1.1.dist-info/entry_points.txt +2 -0
  106. roborock_cli-0.1.1.dist-info/licenses/LICENSE +674 -0
@@ -0,0 +1,108 @@
1
+ """An MQTT session for sending and receiving messages."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import Callable
5
+ from dataclasses import dataclass, field
6
+
7
+ from roborock_cli._vendor.roborock.diagnostics import Diagnostics
8
+ from roborock_cli._vendor.roborock.exceptions import RoborockException
9
+ from roborock_cli._vendor.roborock.mqtt.health_manager import HealthManager
10
+
11
# Default number of seconds to wait on communications with the broker.
DEFAULT_TIMEOUT: float = 30.0

# Zero-argument callback run when the broker reports an authorization error.
SessionUnauthorizedHook = Callable[[], None]
14
+
15
+
16
@dataclass
class MqttParams:
    """Settings needed to open an MQTT session.

    The broker address, TLS flag and credentials are required. The remaining
    fields have defaults.
    """

    host: str
    """Hostname of the MQTT broker."""

    port: int
    """TCP port of the MQTT broker."""

    tls: bool
    """Whether the connection is wrapped in TLS."""

    username: str
    """Username presented to the broker."""

    password: str
    """Password presented to the broker."""

    verify_tls: bool = True
    """Whether the broker's TLS certificate is verified."""

    timeout: float = DEFAULT_TIMEOUT
    """Seconds to wait on communications with the broker."""

    diagnostics: Diagnostics = field(default_factory=Diagnostics)
    """Collector for MQTT session statistics.

    A fresh Diagnostics object is created when none is supplied. Callers
    usually pass a shared one (for example from a DeviceManager) so that the
    shared session's statistics appear in the overall diagnostics.
    """

    unauthorized_hook: SessionUnauthorizedHook | None = None
    """Optional callback for broker authorization errors.

    Background reconnect logic may call this when the broker rejects the
    session. Callers can use it to refresh credentials or react in another
    way.
    """
56
+
57
+
58
class MqttSession(ABC):
    """Abstract interface of an MQTT session used to send and receive messages.

    Concrete subclasses implement every member below.
    """

    @property
    @abstractmethod
    def connected(self) -> bool:
        """Whether the session is currently connected to the broker."""

    @property
    @abstractmethod
    def health_manager(self) -> HealthManager:
        """The HealthManager tracking this session."""

    @abstractmethod
    async def subscribe(self, device_id: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
        """Arrange for ``callback`` to be called with each received message.

        Args:
            device_id: Identifies what to subscribe to. Presumably the topic
                is derived from it; confirm in the concrete implementation.
            callback: Receives the raw bytes of every incoming message.

        Returns:
            A zero-argument callable that cancels the subscription.
        """

    @abstractmethod
    async def publish(self, topic: str, message: bytes) -> None:
        """Send ``message`` on ``topic``.

        An exception is raised if the message could not be sent.
        """

    @abstractmethod
    async def restart(self) -> None:
        """Disconnect from the broker and connect again."""

    @abstractmethod
    async def close(self) -> None:
        """Cancel the MQTT loop."""
92
+
93
+
94
class MqttSessionException(RoborockException):
    """Error raised when communication over MQTT fails."""
96
+
97
+
98
class MqttSessionUnauthorized(RoborockException):
    """Error raised when the MQTT broker rejects the session's authorization.

    The broker reports this in more than one situation, so there is no single
    correct reaction for the caller. Two cases are known:

    - Rate limiting is active, and a retry after some delay may succeed.
    - The credentials are invalid, and new ones must be obtained.

    In practice, obtaining new credentials has been observed to resolve the
    problem in both cases.
    """
@@ -0,0 +1,558 @@
1
+ from __future__ import annotations
2
+
3
+ import binascii
4
+ import gzip
5
+ import hashlib
6
+ import logging
7
+ from collections.abc import Callable
8
+ from urllib.parse import urlparse
9
+
10
+ from construct import ( # type: ignore
11
+ Bytes,
12
+ Checksum,
13
+ ChecksumError,
14
+ Construct,
15
+ Container,
16
+ GreedyBytes,
17
+ GreedyRange,
18
+ Int16ub,
19
+ Int32ub,
20
+ Optional,
21
+ Peek,
22
+ RawCopy,
23
+ Struct,
24
+ bytestringtype,
25
+ stream_seek,
26
+ stream_tell,
27
+ )
28
+ from Crypto.Cipher import AES
29
+ from Crypto.Util.Padding import pad, unpad
30
+
31
+ from roborock_cli._vendor.roborock.data import RRiot
32
+ from roborock_cli._vendor.roborock.exceptions import RoborockException
33
+ from roborock_cli._vendor.roborock.mqtt.session import MqttParams
34
+ from roborock_cli._vendor.roborock.roborock_message import RoborockMessage
35
+
36
_LOGGER = logging.getLogger(__name__)

# Salt appended to the key material when deriving per-message keys
# (Utils._l01_key and the AES-ECB token function of _Message).
SALT = b"TXdfu$jyZ#TZHsg4"

# Strings hashed together with the message nonce to derive AES-CBC IVs for the
# A01 and B01 protocol versions (see EncryptionAdapter).
A01_HASH = "726f626f726f636b2d67a6d6da"
B01_HASH = "5wwh9ikChRjASpMU8cxg7o1d2E"

# Not referenced in this module. Presumably identifiers for AP configuration and
# socket discovery; confirm against callers.
AP_CONFIG = 1
SOCK_DISCOVERY = 2
42
+
43
+
44
def md5hex(message: str) -> str:
    """Return the lowercase hex MD5 digest of ``message`` encoded as UTF-8.

    Used in this module to derive AES-CBC IVs and the MQTT username/password.
    """
    return hashlib.md5(message.encode()).hexdigest()
48
+
49
+
50
+ class Utils:
51
+ """Util class for protocol manipulation."""
52
+
53
+ @staticmethod
54
+ def verify_token(token: bytes):
55
+ """Checks if the given token is of correct type and length."""
56
+ if not isinstance(token, bytes):
57
+ raise TypeError("Token must be bytes")
58
+ if len(token) != 16:
59
+ raise ValueError("Wrong token length")
60
+
61
+ @staticmethod
62
+ def ensure_bytes(msg: bytes | str) -> bytes:
63
+ if isinstance(msg, str):
64
+ return msg.encode()
65
+ return msg
66
+
67
+ @staticmethod
68
+ def encode_timestamp(_timestamp: int) -> bytes:
69
+ hex_value = f"{_timestamp:x}".zfill(8)
70
+ return "".join(list(map(lambda idx: hex_value[idx], [5, 6, 3, 7, 1, 2, 0, 4]))).encode()
71
+
72
+ @staticmethod
73
+ def md5(data: bytes) -> bytes:
74
+ """Calculates a md5 hashsum for the given bytes object."""
75
+ checksum = hashlib.md5() # nosec
76
+ checksum.update(data)
77
+ return checksum.digest()
78
+
79
+ @staticmethod
80
+ def encrypt_ecb(plaintext: bytes, token: bytes) -> bytes:
81
+ """Encrypt plaintext with a given token using ecb mode.
82
+
83
+ :param bytes plaintext: Plaintext (json) to encrypt
84
+ :param bytes token: Token to use
85
+ :return: Encrypted bytes
86
+ """
87
+ if not isinstance(plaintext, bytes):
88
+ raise TypeError("plaintext requires bytes")
89
+ Utils.verify_token(token)
90
+ cipher = AES.new(token, AES.MODE_ECB)
91
+ if plaintext:
92
+ plaintext = pad(plaintext, AES.block_size)
93
+ return cipher.encrypt(plaintext)
94
+ return plaintext
95
+
96
+ @staticmethod
97
+ def decrypt_ecb(ciphertext: bytes, token: bytes) -> bytes:
98
+ """Decrypt ciphertext with a given token using ecb mode.
99
+
100
+ :param bytes ciphertext: Ciphertext to decrypt
101
+ :param bytes token: Token to use
102
+ :return: Decrypted bytes object
103
+ """
104
+ if not isinstance(ciphertext, bytes):
105
+ raise TypeError("ciphertext requires bytes")
106
+ if ciphertext:
107
+ Utils.verify_token(token)
108
+
109
+ aes_key = token
110
+ decipher = AES.new(aes_key, AES.MODE_ECB)
111
+ return unpad(decipher.decrypt(ciphertext), AES.block_size)
112
+ return ciphertext
113
+
114
+ @staticmethod
115
+ def encrypt_cbc(plaintext: bytes, token: bytes) -> bytes:
116
+ """Encrypt plaintext with a given token using cbc mode.
117
+
118
+ This is currently used for testing purposes only.
119
+
120
+ :param bytes plaintext: Plaintext (json) to encrypt
121
+ :param bytes token: Token to use
122
+ :return: Encrypted bytes
123
+ """
124
+ if not isinstance(plaintext, bytes):
125
+ raise TypeError("plaintext requires bytes")
126
+ Utils.verify_token(token)
127
+ iv = bytes(AES.block_size)
128
+ cipher = AES.new(token, AES.MODE_CBC, iv)
129
+ if plaintext:
130
+ plaintext = pad(plaintext, AES.block_size)
131
+ return cipher.encrypt(plaintext)
132
+ return plaintext
133
+
134
+ @staticmethod
135
+ def decrypt_cbc(ciphertext: bytes, token: bytes) -> bytes:
136
+ """Decrypt ciphertext with a given token using cbc mode.
137
+
138
+ :param bytes ciphertext: Ciphertext to decrypt
139
+ :param bytes token: Token to use
140
+ :return: Decrypted bytes object
141
+ """
142
+ if not isinstance(ciphertext, bytes):
143
+ raise TypeError("ciphertext requires bytes")
144
+ if ciphertext:
145
+ Utils.verify_token(token)
146
+
147
+ iv = bytes(AES.block_size)
148
+ decipher = AES.new(token, AES.MODE_CBC, iv)
149
+ return unpad(decipher.decrypt(ciphertext), AES.block_size)
150
+ return ciphertext
151
+
152
+ @staticmethod
153
+ def _l01_key(local_key: str, timestamp: int) -> bytes:
154
+ """Derive key for L01 protocol."""
155
+ hash_input = Utils.encode_timestamp(timestamp) + Utils.ensure_bytes(local_key) + SALT
156
+ return hashlib.sha256(hash_input).digest()
157
+
158
+ @staticmethod
159
+ def _l01_iv(timestamp: int, nonce: int, sequence: int) -> bytes:
160
+ """Derive IV for L01 protocol."""
161
+ digest_input = sequence.to_bytes(4, "big") + nonce.to_bytes(4, "big") + timestamp.to_bytes(4, "big")
162
+ digest = hashlib.sha256(digest_input).digest()
163
+ return digest[:12]
164
+
165
+ @staticmethod
166
+ def _l01_aad(timestamp: int, nonce: int, sequence: int, connect_nonce: int, ack_nonce: int | None = None) -> bytes:
167
+ """Derive AAD for L01 protocol."""
168
+ return (
169
+ sequence.to_bytes(4, "big")
170
+ + connect_nonce.to_bytes(4, "big")
171
+ + (ack_nonce.to_bytes(4, "big") if ack_nonce is not None else b"")
172
+ + nonce.to_bytes(4, "big")
173
+ + timestamp.to_bytes(4, "big")
174
+ )
175
+
176
+ @staticmethod
177
+ def encrypt_gcm_l01(
178
+ plaintext: bytes,
179
+ local_key: str,
180
+ timestamp: int,
181
+ sequence: int,
182
+ nonce: int,
183
+ connect_nonce: int,
184
+ ack_nonce: int | None = None,
185
+ ) -> bytes:
186
+ """Encrypt plaintext for L01 protocol using AES-256-GCM."""
187
+ if not isinstance(plaintext, bytes):
188
+ raise TypeError("plaintext requires bytes")
189
+
190
+ key = Utils._l01_key(local_key, timestamp)
191
+ iv = Utils._l01_iv(timestamp, nonce, sequence)
192
+ aad = Utils._l01_aad(timestamp, nonce, sequence, connect_nonce, ack_nonce)
193
+
194
+ cipher = AES.new(key, AES.MODE_GCM, nonce=iv)
195
+ cipher.update(aad)
196
+ ciphertext, tag = cipher.encrypt_and_digest(plaintext)
197
+
198
+ return ciphertext + tag
199
+
200
+ @staticmethod
201
+ def decrypt_gcm_l01(
202
+ payload: bytes,
203
+ local_key: str,
204
+ timestamp: int,
205
+ sequence: int,
206
+ nonce: int,
207
+ connect_nonce: int,
208
+ ack_nonce: int,
209
+ ) -> bytes:
210
+ """Decrypt payload for L01 protocol using AES-256-GCM."""
211
+ if not isinstance(payload, bytes):
212
+ raise TypeError("payload requires bytes")
213
+
214
+ key = Utils._l01_key(local_key, timestamp)
215
+ iv = Utils._l01_iv(timestamp, nonce, sequence)
216
+ aad = Utils._l01_aad(timestamp, nonce, sequence, connect_nonce, ack_nonce)
217
+
218
+ if len(payload) < 16:
219
+ raise ValueError("Invalid payload length for GCM decryption")
220
+
221
+ tag = payload[-16:]
222
+ ciphertext = payload[:-16]
223
+
224
+ cipher = AES.new(key, AES.MODE_GCM, nonce=iv)
225
+ cipher.update(aad)
226
+
227
+ try:
228
+ return cipher.decrypt_and_verify(ciphertext, tag)
229
+ except ValueError as e:
230
+ raise RoborockException("GCM tag verification failed") from e
231
+
232
+ @staticmethod
233
+ def crc(data: bytes) -> int:
234
+ """Gather bytes for checksum calculation."""
235
+ return binascii.crc32(data)
236
+
237
+ @staticmethod
238
+ def decompress(compressed_data: bytes):
239
+ """Decompress data using gzip."""
240
+ return gzip.decompress(compressed_data)
241
+
242
+
243
class EncryptionAdapter(Construct):
    """Construct that frames and encrypts/decrypts a message payload.

    On the wire the payload is a big-endian 16-bit length followed by that
    many encrypted bytes. The cipher depends on ``context.version``:

    - A01/B01 use AES-CBC.
    - L01 uses AES-GCM.
    - Any other version uses AES-ECB keyed by ``token_func(context)``.
    """

    def __init__(self, token_func: Callable):
        super().__init__()
        # Called with the parse/build context to obtain the AES-ECB key.
        self.token_func = token_func

    def _parse(self, stream, context, path):
        length_field = Optional(Int16ub)
        size = length_field.parse_stream(stream, **context)
        if size is None:
            # No length field could be read: there is no payload.
            return None
        if size == 0:
            # Zero length: consume two more bytes ("seek 2") and report no payload.
            length_field.parse_stream(stream, **context)
            return None
        ciphertext = Bytes(size).parse_stream(stream, **context)
        return self._decode(ciphertext, context, path)

    def _build(self, obj, stream, context, path):
        if obj is None:
            # An absent payload writes nothing at all.
            return obj
        ciphertext = self._encode(obj, context, path)
        size = len(ciphertext)
        Int16ub.build_stream(size, stream, **context)
        Bytes(size).build_stream(ciphertext, stream, **context)
        return obj

    @staticmethod
    def _cbc_cipher(context, iv_salt: str, iv_start: int):
        """Create the AES-CBC cipher used by the A01/B01 versions.

        The IV is the 16 hex characters of ``md5hex(<random as 8 hex digits> + iv_salt)``
        starting at ``iv_start``. The key is the UTF-8 encoded local key.
        """
        iv = md5hex(f"{context.random:08x}" + iv_salt)[iv_start : iv_start + 16]
        key = bytes(context.search("local_key"), "utf-8")
        return AES.new(key, AES.MODE_CBC, bytes(iv, "utf-8"))

    def _encode(self, obj, context, _):
        """Encrypt the plaintext payload ``obj`` according to ``context.version``."""
        version = context.version
        if version == b"A01":
            # A01: ``obj`` is encrypted as-is; no padding is applied here.
            return self._cbc_cipher(context, A01_HASH, 8).encrypt(obj)
        if version == b"B01":
            return self._cbc_cipher(context, B01_HASH, 9).encrypt(pad(obj, AES.block_size))
        if version == b"L01":
            return Utils.encrypt_gcm_l01(
                plaintext=obj,
                local_key=context.search("local_key"),
                timestamp=context.timestamp,
                sequence=context.seq,
                nonce=context.random,
                connect_nonce=context.search("connect_nonce"),
                ack_nonce=context.search("ack_nonce"),
            )
        return Utils.encrypt_ecb(obj, self.token_func(context))

    def _decode(self, obj, context, _):
        """Decrypt the ciphertext payload ``obj`` according to ``context.version``."""
        version = context.version
        if version == b"A01":
            # A01: no unpadding is applied here.
            return self._cbc_cipher(context, A01_HASH, 8).decrypt(obj)
        if version == b"B01":
            return unpad(self._cbc_cipher(context, B01_HASH, 9).decrypt(obj), AES.block_size)
        if version == b"L01":
            return Utils.decrypt_gcm_l01(
                payload=obj,
                local_key=context.search("local_key"),
                timestamp=context.timestamp,
                sequence=context.seq,
                nonce=context.random,
                connect_nonce=context.search("connect_nonce"),
                ack_nonce=context.search("ack_nonce"),
            )
        return Utils.decrypt_ecb(obj, self.token_func(context))
321
+
322
+
323
class OptionalChecksum(Checksum):
    """Checksum field that is only read when the parsed message has a payload."""

    def _parse(self, stream, context, path):
        if not context.message.value.payload:
            # Empty/absent payload: the checksum field is not read.
            return None
        stored = self.checksumfield.parse_stream(stream, **context)
        computed = self.hashfunc(self.bytesfunc(context))
        if stored == computed:
            return stored

        def _show(value):
            # Render byte strings as hex for the error message.
            return binascii.hexlify(value) if isinstance(value, bytestringtype) else value

        raise ChecksumError(
            f"wrong checksum, read {_show(stored)}, computed {_show(computed)}",
            path=path,
        )
336
+
337
+
338
class PrefixedStruct(Struct):
    """Struct for one framed message that tolerates leading junk when parsing.

    Parsing: if the stream does not begin with a known protocol version
    marker, it is scanned forward to the earliest marker and parsing resumes
    there. The usual case is skipping a 4-byte length prefix.

    Building: when the context flag ``prefixed`` is truthy, the struct is
    written after a 4-byte field holding the length of the built struct.
    Otherwise it is built as a plain Struct.
    """

    def _parse(self, stream, context, path):
        valid_versions = (b"1.0", b"A01", b"B01", b"L01")
        peek_version = Peek(Optional(Bytes(3))).parse_stream(stream, **context)

        if peek_version not in valid_versions:
            # Current stream position does not start with a valid version.
            # Scan forward to find one.
            current_pos = stream_tell(stream, path)
            data = stream.read()

            if not data:
                # EOF reached, let the parser fail naturally without logging.
                stream_seek(stream, current_pos, 0, path)
                return super()._parse(stream, context, path)

            # Earliest offset at which any valid version occurs, or -1 if none.
            # bytes.find scans in C rather than with a per-byte Python loop.
            offsets = [pos for pos in (data.find(version) for version in valid_versions) if pos != -1]
            start_index = min(offsets) if offsets else -1

            if start_index != -1:
                if start_index != 4:
                    # 4 is the typical/expected amount we prune off,
                    # so only log when a different length is stripped.
                    _LOGGER.debug("Stripping %d bytes of invalid data from stream", start_index)
                stream_seek(stream, current_pos + start_index, 0, path)
            else:
                _LOGGER.debug("No valid version header found in stream, continuing anyways...")
                # Seek back to the original position to avoid parsing at EOF.
                stream_seek(stream, current_pos, 0, path)

        return super()._parse(stream, context, path)

    def _build(self, obj, stream, context, path):
        prefixed = context.search("prefixed")
        if not prefixed:
            return super()._build(obj, stream, context, path)
        prefix = Bytes(4)
        offset = stream_tell(stream, path)
        # Reserve room for the 4-byte prefix and build the struct after it.
        stream_seek(stream, offset + 4, 0, path)
        super()._build(obj, stream, context, path)
        new_offset = stream_tell(stream, path)
        # Go back and fill the prefix with the length of the built struct.
        stream_seek(stream, offset, 0, path)
        prefix.build_stream(new_offset - offset - prefix.sizeof(**context), stream, **context)
        # new_offset already accounts for the prefix, so resume exactly at the
        # end of the built data. Consecutive messages are then contiguous, with
        # no gap between them.
        stream_seek(stream, new_offset, 0, path)
        return obj
391
+
392
+
393
def _ecb_token(ctx) -> bytes:
    """AES-ECB key for versions without a dedicated cipher.

    The key is md5(encoded timestamp + local key + SALT).
    """
    key_material = Utils.encode_timestamp(ctx.timestamp) + Utils.ensure_bytes(ctx.search("local_key")) + SALT
    return Utils.md5(key_material)


# A single protocol message. RawCopy exposes the raw bytes next to the parsed
# fields, and the checksum in _Messages is computed over those bytes.
_Message = RawCopy(
    Struct(
        "version" / Bytes(3),
        "seq" / Int32ub,
        "random" / Int32ub,
        "timestamp" / Int32ub,
        "protocol" / Int16ub,
        "payload" / EncryptionAdapter(_ecb_token),
    )
)
408
+
409
# A byte stream holding as many (optionally length-prefixed) checksummed
# messages as can be parsed, followed by any unparsed trailing bytes.
_Messages = Struct(
    "messages"
    / GreedyRange(
        PrefixedStruct(
            "message" / _Message,
            "checksum"
            / OptionalChecksum(
                Optional(Int32ub),
                Utils.crc,
                lambda ctx: ctx.message.data,
            ),
        )
    ),
    "remaining" / Optional(GreedyBytes),
)
419
+
420
+
421
class _Parser:
    """Converts between raw protocol bytes and RoborockMessage objects."""

    def __init__(self, con: Construct, required_local_key: bool):
        # ``con`` describes the wire format; ``required_local_key`` makes
        # parse() reject calls without a local key.
        self.con = con
        self.required_local_key = required_local_key

    @staticmethod
    def _to_roborock_message(value) -> RoborockMessage:
        """Build a RoborockMessage from one parsed message value container."""
        return RoborockMessage(
            version=value.version,
            seq=value.get("seq"),
            random=value.get("random"),
            timestamp=value.get("timestamp"),
            protocol=value.get("protocol"),
            payload=value.payload,
        )

    def parse(
        self, data: bytes, local_key: str | None = None, connect_nonce: int | None = None, ack_nonce: int | None = None
    ) -> tuple[list[RoborockMessage], bytes]:
        """Parse ``data`` into messages.

        :return: The parsed messages and any trailing bytes that were not parsed.
        :raises RoborockException: if a local key is required but not given.
        """
        if self.required_local_key and local_key is None:
            raise RoborockException("Local key is required")
        parsed = self.con.parse(data, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce)
        # Support constructs that yield a single "message" as well as a "messages" list.
        if parsed.get("message"):
            entries = [Container({"message": parsed.message})]
        else:
            entries = parsed.messages
        messages = [self._to_roborock_message(entry.message.value) for entry in entries]
        return messages, parsed.get("remaining") or b""

    def build(
        self,
        roborock_messages: list[RoborockMessage] | RoborockMessage,
        local_key: str,
        prefixed: bool = True,
        connect_nonce: int | None = None,
        ack_nonce: int | None = None,
    ) -> bytes:
        """Serialise one message or a list of messages into bytes."""
        if isinstance(roborock_messages, RoborockMessage):
            roborock_messages = [roborock_messages]
        containers = [
            {
                "message": {
                    "value": {
                        "version": msg.version,
                        "seq": msg.seq,
                        "random": msg.random,
                        "timestamp": msg.timestamp,
                        "protocol": msg.protocol,
                        "payload": msg.payload,
                    }
                },
            }
            for msg in roborock_messages
        ]
        return self.con.build(
            {"messages": containers, "remaining": b""},
            local_key=local_key,
            prefixed=prefixed,
            connect_nonce=connect_nonce,
            ack_nonce=ack_nonce,
        )
481
+
482
+
483
# Shared parser for the message stream format; parse() requires a local key.
MessageParser: _Parser = _Parser(_Messages, required_local_key=True)
484
+
485
+
486
def create_mqtt_params(rriot: RRiot) -> MqttParams:
    """Return the MQTT parameters for this user.

    The broker URL is taken from ``rriot.r.m``, and TLS is enabled when its
    scheme is ``ssl``. The username is ``md5hex(u:k)[2:10]`` and the password
    is ``md5hex(s:k)[16:]``.

    :raises RoborockException: if the URL has no hostname, or its port is
        missing, zero, non-numeric, or out of range.
    """
    url = urlparse(rriot.r.m)
    if not isinstance(url.hostname, str):
        raise RoborockException(f"Url parsing '{rriot.r.m}' returned an invalid hostname")
    try:
        port = url.port
    except ValueError as err:
        # urllib raises ValueError when reading a non-numeric or out-of-range port.
        raise RoborockException(f"Url parsing '{rriot.r.m}' returned an invalid port") from err
    if not port:
        raise RoborockException(f"Url parsing '{rriot.r.m}' returned an invalid port")
    hashed_user = md5hex(rriot.u + ":" + rriot.k)[2:10]
    hashed_password = md5hex(rriot.s + ":" + rriot.k)[16:]
    return MqttParams(
        host=url.hostname,
        port=port,
        tls=(url.scheme == "ssl"),
        username=hashed_user,
        password=hashed_password,
    )
502
+
503
+
504
# Turns received bytes into the messages they contain.
Decoder = Callable[[bytes], list[RoborockMessage]]
# Serialises a single message into bytes.
Encoder = Callable[[RoborockMessage], bytes]
506
+
507
+
508
def create_mqtt_decoder(local_key: str) -> Decoder:
    """Return a stateless decoder for MQTT payloads keyed with ``local_key``.

    Each call parses the given bytes on their own. Any unparsed trailing bytes
    are discarded.
    """

    def decode(data: bytes) -> list[RoborockMessage]:
        """Return the Roborock messages contained in ``data``."""
        return MessageParser.parse(data, local_key)[0]

    return decode
517
+
518
+
519
def create_mqtt_encoder(local_key: str) -> Encoder:
    """Return an encoder producing MQTT payloads keyed with ``local_key``."""

    def encode(messages: RoborockMessage) -> bytes:
        """Serialise ``messages`` into the bytes to publish."""
        # MQTT payloads are built without the 4-byte length prefix.
        return MessageParser.build(messages, local_key=local_key, prefixed=False)

    return encode
527
+
528
+
529
def create_local_decoder(local_key: str, connect_nonce: int | None = None, ack_nonce: int | None = None) -> Decoder:
    """Return a stateful decoder for the local API byte stream.

    Bytes that do not yet form a complete message are kept between calls and
    prepended to the next chunk. If parsing raises, the accumulated bytes
    remain buffered.
    """
    pending: bytes = b""

    def decode(bytes_data: bytes) -> list[RoborockMessage]:
        """Append ``bytes_data`` to the buffer and return every message now complete."""
        nonlocal pending
        pending += bytes_data
        messages, leftover = MessageParser.parse(
            pending,
            local_key=local_key,
            connect_nonce=connect_nonce,
            ack_nonce=ack_nonce,
        )
        if leftover:
            _LOGGER.debug("Found %d extra bytes: %s", len(leftover), leftover)
        pending = leftover
        return messages

    return decode
549
+
550
+
551
def create_local_encoder(local_key: str, connect_nonce: int | None = None, ack_nonce: int | None = None) -> Encoder:
    """Return an encoder producing length-prefixed frames for the local API."""

    def encode(message: RoborockMessage) -> bytes:
        """Serialise ``message`` into the bytes written to the transport."""
        return MessageParser.build(
            message,
            local_key=local_key,
            connect_nonce=connect_nonce,
            ack_nonce=ack_nonce,
        )

    return encode
@@ -0,0 +1,3 @@
1
+ """Protocols for communicating with Roborock devices."""
2
+
3
# This package exports no public names via star-imports.
__all__: list[str] = []