roborock-cli 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roborock_cli/__init__.py +3 -0
- roborock_cli/__main__.py +76 -0
- roborock_cli/_vendor/VERSION +6 -0
- roborock_cli/_vendor/__init__.py +0 -0
- roborock_cli/_vendor/roborock/__init__.py +27 -0
- roborock_cli/_vendor/roborock/broadcast_protocol.py +114 -0
- roborock_cli/_vendor/roborock/callbacks.py +130 -0
- roborock_cli/_vendor/roborock/cli.py +1338 -0
- roborock_cli/_vendor/roborock/const.py +84 -0
- roborock_cli/_vendor/roborock/data/__init__.py +9 -0
- roborock_cli/_vendor/roborock/data/b01_q10/__init__.py +2 -0
- roborock_cli/_vendor/roborock/data/b01_q10/b01_q10_code_mappings.py +213 -0
- roborock_cli/_vendor/roborock/data/b01_q10/b01_q10_containers.py +102 -0
- roborock_cli/_vendor/roborock/data/b01_q7/__init__.py +2 -0
- roborock_cli/_vendor/roborock/data/b01_q7/b01_q7_code_mappings.py +303 -0
- roborock_cli/_vendor/roborock/data/b01_q7/b01_q7_containers.py +302 -0
- roborock_cli/_vendor/roborock/data/code_mappings.py +198 -0
- roborock_cli/_vendor/roborock/data/containers.py +530 -0
- roborock_cli/_vendor/roborock/data/dyad/__init__.py +2 -0
- roborock_cli/_vendor/roborock/data/dyad/dyad_code_mappings.py +102 -0
- roborock_cli/_vendor/roborock/data/dyad/dyad_containers.py +28 -0
- roborock_cli/_vendor/roborock/data/v1/__init__.py +3 -0
- roborock_cli/_vendor/roborock/data/v1/v1_clean_modes.py +192 -0
- roborock_cli/_vendor/roborock/data/v1/v1_code_mappings.py +644 -0
- roborock_cli/_vendor/roborock/data/v1/v1_containers.py +800 -0
- roborock_cli/_vendor/roborock/data/zeo/__init__.py +2 -0
- roborock_cli/_vendor/roborock/data/zeo/zeo_code_mappings.py +138 -0
- roborock_cli/_vendor/roborock/data/zeo/zeo_containers.py +0 -0
- roborock_cli/_vendor/roborock/device_features.py +668 -0
- roborock_cli/_vendor/roborock/devices/README.md +41 -0
- roborock_cli/_vendor/roborock/devices/__init__.py +11 -0
- roborock_cli/_vendor/roborock/devices/cache.py +143 -0
- roborock_cli/_vendor/roborock/devices/device.py +240 -0
- roborock_cli/_vendor/roborock/devices/device_manager.py +269 -0
- roborock_cli/_vendor/roborock/devices/file_cache.py +79 -0
- roborock_cli/_vendor/roborock/devices/rpc/__init__.py +14 -0
- roborock_cli/_vendor/roborock/devices/rpc/a01_channel.py +94 -0
- roborock_cli/_vendor/roborock/devices/rpc/b01_q10_channel.py +57 -0
- roborock_cli/_vendor/roborock/devices/rpc/b01_q7_channel.py +101 -0
- roborock_cli/_vendor/roborock/devices/rpc/v1_channel.py +457 -0
- roborock_cli/_vendor/roborock/devices/traits/__init__.py +28 -0
- roborock_cli/_vendor/roborock/devices/traits/a01/__init__.py +191 -0
- roborock_cli/_vendor/roborock/devices/traits/b01/__init__.py +12 -0
- roborock_cli/_vendor/roborock/devices/traits/b01/q10/__init__.py +76 -0
- roborock_cli/_vendor/roborock/devices/traits/b01/q10/command.py +32 -0
- roborock_cli/_vendor/roborock/devices/traits/b01/q10/common.py +115 -0
- roborock_cli/_vendor/roborock/devices/traits/b01/q10/status.py +32 -0
- roborock_cli/_vendor/roborock/devices/traits/b01/q10/vacuum.py +81 -0
- roborock_cli/_vendor/roborock/devices/traits/b01/q7/__init__.py +136 -0
- roborock_cli/_vendor/roborock/devices/traits/b01/q7/clean_summary.py +75 -0
- roborock_cli/_vendor/roborock/devices/traits/traits_mixin.py +64 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/__init__.py +344 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/child_lock.py +29 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/clean_summary.py +83 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/command.py +38 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/common.py +172 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/consumeable.py +48 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/device_features.py +74 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/do_not_disturb.py +41 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/dust_collection_mode.py +13 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/flow_led_status.py +29 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/home.py +285 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/led_status.py +43 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/map_content.py +83 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/maps.py +80 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/network_info.py +55 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/rooms.py +105 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/routines.py +26 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/smart_wash_params.py +13 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/status.py +101 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/valley_electricity_timer.py +44 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/volume.py +27 -0
- roborock_cli/_vendor/roborock/devices/traits/v1/wash_towel_mode.py +13 -0
- roborock_cli/_vendor/roborock/devices/transport/__init__.py +8 -0
- roborock_cli/_vendor/roborock/devices/transport/channel.py +32 -0
- roborock_cli/_vendor/roborock/devices/transport/local_channel.py +295 -0
- roborock_cli/_vendor/roborock/devices/transport/mqtt_channel.py +118 -0
- roborock_cli/_vendor/roborock/diagnostics.py +166 -0
- roborock_cli/_vendor/roborock/exceptions.py +95 -0
- roborock_cli/_vendor/roborock/map/__init__.py +7 -0
- roborock_cli/_vendor/roborock/map/map_parser.py +123 -0
- roborock_cli/_vendor/roborock/mqtt/__init__.py +10 -0
- roborock_cli/_vendor/roborock/mqtt/health_manager.py +60 -0
- roborock_cli/_vendor/roborock/mqtt/roborock_session.py +463 -0
- roborock_cli/_vendor/roborock/mqtt/session.py +108 -0
- roborock_cli/_vendor/roborock/protocol.py +558 -0
- roborock_cli/_vendor/roborock/protocols/__init__.py +3 -0
- roborock_cli/_vendor/roborock/protocols/a01_protocol.py +74 -0
- roborock_cli/_vendor/roborock/protocols/b01_q10_protocol.py +87 -0
- roborock_cli/_vendor/roborock/protocols/b01_q7_protocol.py +81 -0
- roborock_cli/_vendor/roborock/protocols/v1_protocol.py +271 -0
- roborock_cli/_vendor/roborock/py.typed +0 -0
- roborock_cli/_vendor/roborock/roborock_message.py +246 -0
- roborock_cli/_vendor/roborock/roborock_typing.py +382 -0
- roborock_cli/_vendor/roborock/util.py +54 -0
- roborock_cli/_vendor/roborock/web_api.py +761 -0
- roborock_cli/cli.py +715 -0
- roborock_cli/connection.py +202 -0
- roborock_cli/helpers.py +71 -0
- roborock_cli/server.py +759 -0
- roborock_cli/setup_auth.py +92 -0
- roborock_cli-0.1.1.dist-info/METADATA +172 -0
- roborock_cli-0.1.1.dist-info/RECORD +106 -0
- roborock_cli-0.1.1.dist-info/WHEEL +4 -0
- roborock_cli-0.1.1.dist-info/entry_points.txt +2 -0
- roborock_cli-0.1.1.dist-info/licenses/LICENSE +674 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""An MQTT session for sending and receiving messages."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
|
|
7
|
+
from roborock_cli._vendor.roborock.diagnostics import Diagnostics
|
|
8
|
+
from roborock_cli._vendor.roborock.exceptions import RoborockException
|
|
9
|
+
from roborock_cli._vendor.roborock.mqtt.health_manager import HealthManager
|
|
10
|
+
|
|
11
|
+
# Default timeout, in seconds, for communications with the MQTT broker.
DEFAULT_TIMEOUT: float = 30.0

# Signature of the zero-argument hook invoked when the broker reports an
# unauthorized error (used by ``MqttParams.unauthorized_hook``).
SessionUnauthorizedHook = Callable[[], None]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class MqttParams:
    """MQTT parameters for the connection.

    ``host``, ``port``, ``tls``, ``username`` and ``password`` are required;
    the remaining fields have defaults.
    """

    # MQTT host to connect to.
    host: str

    # MQTT port to connect to.
    port: int

    # Use TLS for the connection.
    tls: bool

    # MQTT username used for authentication.
    username: str

    # MQTT password used for authentication.
    password: str

    # Whether to verify the broker's TLS certificate.
    verify_tls: bool = True

    # Timeout for communications with the broker, in seconds.
    timeout: float = DEFAULT_TIMEOUT

    # Diagnostics object for tracking MQTT session stats. A fresh instance is
    # created by default, but the common case is that the caller supplies its
    # own (e.g. from a DeviceManager) so the shared MQTT session stats are
    # included in the overall diagnostics.
    diagnostics: Diagnostics = field(default_factory=Diagnostics)

    # Optional hook invoked when an unauthorized error is received. The
    # background reconnect logic may call it when the broker rejects the
    # session; callers can use it to refresh credentials or take other action.
    unauthorized_hook: SessionUnauthorizedHook | None = None
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class MqttSession(ABC):
    """Abstract MQTT session used to send and receive messages.

    Concrete sessions must implement every abstract member below.
    """

    @property
    @abstractmethod
    def connected(self) -> bool:
        """Whether the session is currently connected to the broker."""

    @property
    @abstractmethod
    def health_manager(self) -> HealthManager:
        """The health manager belonging to this session."""

    @abstractmethod
    async def subscribe(self, device_id: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
        """Call ``callback`` with each message received on the topic for ``device_id``.

        Returns a callable; invoking it removes the subscription.
        """

    @abstractmethod
    async def publish(self, topic: str, message: bytes) -> None:
        """Publish ``message`` on ``topic``.

        Raises an exception if the message could not be sent.
        """

    @abstractmethod
    async def restart(self) -> None:
        """Force the session to disconnect and then reconnect."""

    @abstractmethod
    async def close(self) -> None:
        """Cancel the MQTT loop."""
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class MqttSessionException(RoborockException):
    """Error raised when communicating with the MQTT broker fails."""
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class MqttSessionUnauthorized(RoborockException):
    """Error raised when the MQTT broker reports an authorization failure.

    This error is raised in more than one situation, so there is no single
    well-defined way for the caller to react. The two known cases are:

    - rate limiting is in effect, and the caller should retry after some time;
    - the credentials are invalid, and the caller must obtain new credentials.

    In practice, obtaining new credentials has been observed to resolve the
    issue in both cases.
    """
|
|
@@ -0,0 +1,558 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import binascii
|
|
4
|
+
import gzip
|
|
5
|
+
import hashlib
|
|
6
|
+
import logging
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from urllib.parse import urlparse
|
|
9
|
+
|
|
10
|
+
from construct import ( # type: ignore
|
|
11
|
+
Bytes,
|
|
12
|
+
Checksum,
|
|
13
|
+
ChecksumError,
|
|
14
|
+
Construct,
|
|
15
|
+
Container,
|
|
16
|
+
GreedyBytes,
|
|
17
|
+
GreedyRange,
|
|
18
|
+
Int16ub,
|
|
19
|
+
Int32ub,
|
|
20
|
+
Optional,
|
|
21
|
+
Peek,
|
|
22
|
+
RawCopy,
|
|
23
|
+
Struct,
|
|
24
|
+
bytestringtype,
|
|
25
|
+
stream_seek,
|
|
26
|
+
stream_tell,
|
|
27
|
+
)
|
|
28
|
+
from Crypto.Cipher import AES
|
|
29
|
+
from Crypto.Util.Padding import pad, unpad
|
|
30
|
+
|
|
31
|
+
from roborock_cli._vendor.roborock.data import RRiot
|
|
32
|
+
from roborock_cli._vendor.roborock.exceptions import RoborockException
|
|
33
|
+
from roborock_cli._vendor.roborock.mqtt.session import MqttParams
|
|
34
|
+
from roborock_cli._vendor.roborock.roborock_message import RoborockMessage
|
|
35
|
+
|
|
36
|
+
_LOGGER = logging.getLogger(__name__)

# Salt appended to the encoded timestamp and local key when deriving both the
# ECB token (default message versions) and the L01 GCM key.
SALT: bytes = b"TXdfu$jyZ#TZHsg4"
# Strings appended to the hex of ``random`` before MD5-hashing to derive the
# A01 and B01 CBC initialisation vectors.
A01_HASH: str = "726f626f726f636b2d67a6d6da"
B01_HASH: str = "5wwh9ikChRjASpMU8cxg7o1d2E"
# NOTE(review): these two are not referenced in this module; presumably they
# are protocol identifiers consumed elsewhere — confirm before relying on them.
AP_CONFIG: int = 1
SOCK_DISCOVERY: int = 2
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def md5hex(message: str) -> str:
    """Return the lowercase hex MD5 digest of ``message`` (encoded as UTF-8).

    MD5 is dictated by the device protocol here (IV and MQTT credential
    derivation) rather than chosen for its security properties, so it is
    requested with ``usedforsecurity=False``. Without that flag, FIPS-restricted
    OpenSSL builds refuse to construct an MD5 hash and the call would fail.
    """
    return hashlib.md5(message.encode(), usedforsecurity=False).hexdigest()  # nosec
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class Utils:
|
|
51
|
+
"""Util class for protocol manipulation."""
|
|
52
|
+
|
|
53
|
+
@staticmethod
|
|
54
|
+
def verify_token(token: bytes):
|
|
55
|
+
"""Checks if the given token is of correct type and length."""
|
|
56
|
+
if not isinstance(token, bytes):
|
|
57
|
+
raise TypeError("Token must be bytes")
|
|
58
|
+
if len(token) != 16:
|
|
59
|
+
raise ValueError("Wrong token length")
|
|
60
|
+
|
|
61
|
+
@staticmethod
|
|
62
|
+
def ensure_bytes(msg: bytes | str) -> bytes:
|
|
63
|
+
if isinstance(msg, str):
|
|
64
|
+
return msg.encode()
|
|
65
|
+
return msg
|
|
66
|
+
|
|
67
|
+
@staticmethod
|
|
68
|
+
def encode_timestamp(_timestamp: int) -> bytes:
|
|
69
|
+
hex_value = f"{_timestamp:x}".zfill(8)
|
|
70
|
+
return "".join(list(map(lambda idx: hex_value[idx], [5, 6, 3, 7, 1, 2, 0, 4]))).encode()
|
|
71
|
+
|
|
72
|
+
@staticmethod
|
|
73
|
+
def md5(data: bytes) -> bytes:
|
|
74
|
+
"""Calculates a md5 hashsum for the given bytes object."""
|
|
75
|
+
checksum = hashlib.md5() # nosec
|
|
76
|
+
checksum.update(data)
|
|
77
|
+
return checksum.digest()
|
|
78
|
+
|
|
79
|
+
@staticmethod
|
|
80
|
+
def encrypt_ecb(plaintext: bytes, token: bytes) -> bytes:
|
|
81
|
+
"""Encrypt plaintext with a given token using ecb mode.
|
|
82
|
+
|
|
83
|
+
:param bytes plaintext: Plaintext (json) to encrypt
|
|
84
|
+
:param bytes token: Token to use
|
|
85
|
+
:return: Encrypted bytes
|
|
86
|
+
"""
|
|
87
|
+
if not isinstance(plaintext, bytes):
|
|
88
|
+
raise TypeError("plaintext requires bytes")
|
|
89
|
+
Utils.verify_token(token)
|
|
90
|
+
cipher = AES.new(token, AES.MODE_ECB)
|
|
91
|
+
if plaintext:
|
|
92
|
+
plaintext = pad(plaintext, AES.block_size)
|
|
93
|
+
return cipher.encrypt(plaintext)
|
|
94
|
+
return plaintext
|
|
95
|
+
|
|
96
|
+
@staticmethod
|
|
97
|
+
def decrypt_ecb(ciphertext: bytes, token: bytes) -> bytes:
|
|
98
|
+
"""Decrypt ciphertext with a given token using ecb mode.
|
|
99
|
+
|
|
100
|
+
:param bytes ciphertext: Ciphertext to decrypt
|
|
101
|
+
:param bytes token: Token to use
|
|
102
|
+
:return: Decrypted bytes object
|
|
103
|
+
"""
|
|
104
|
+
if not isinstance(ciphertext, bytes):
|
|
105
|
+
raise TypeError("ciphertext requires bytes")
|
|
106
|
+
if ciphertext:
|
|
107
|
+
Utils.verify_token(token)
|
|
108
|
+
|
|
109
|
+
aes_key = token
|
|
110
|
+
decipher = AES.new(aes_key, AES.MODE_ECB)
|
|
111
|
+
return unpad(decipher.decrypt(ciphertext), AES.block_size)
|
|
112
|
+
return ciphertext
|
|
113
|
+
|
|
114
|
+
@staticmethod
|
|
115
|
+
def encrypt_cbc(plaintext: bytes, token: bytes) -> bytes:
|
|
116
|
+
"""Encrypt plaintext with a given token using cbc mode.
|
|
117
|
+
|
|
118
|
+
This is currently used for testing purposes only.
|
|
119
|
+
|
|
120
|
+
:param bytes plaintext: Plaintext (json) to encrypt
|
|
121
|
+
:param bytes token: Token to use
|
|
122
|
+
:return: Encrypted bytes
|
|
123
|
+
"""
|
|
124
|
+
if not isinstance(plaintext, bytes):
|
|
125
|
+
raise TypeError("plaintext requires bytes")
|
|
126
|
+
Utils.verify_token(token)
|
|
127
|
+
iv = bytes(AES.block_size)
|
|
128
|
+
cipher = AES.new(token, AES.MODE_CBC, iv)
|
|
129
|
+
if plaintext:
|
|
130
|
+
plaintext = pad(plaintext, AES.block_size)
|
|
131
|
+
return cipher.encrypt(plaintext)
|
|
132
|
+
return plaintext
|
|
133
|
+
|
|
134
|
+
@staticmethod
|
|
135
|
+
def decrypt_cbc(ciphertext: bytes, token: bytes) -> bytes:
|
|
136
|
+
"""Decrypt ciphertext with a given token using cbc mode.
|
|
137
|
+
|
|
138
|
+
:param bytes ciphertext: Ciphertext to decrypt
|
|
139
|
+
:param bytes token: Token to use
|
|
140
|
+
:return: Decrypted bytes object
|
|
141
|
+
"""
|
|
142
|
+
if not isinstance(ciphertext, bytes):
|
|
143
|
+
raise TypeError("ciphertext requires bytes")
|
|
144
|
+
if ciphertext:
|
|
145
|
+
Utils.verify_token(token)
|
|
146
|
+
|
|
147
|
+
iv = bytes(AES.block_size)
|
|
148
|
+
decipher = AES.new(token, AES.MODE_CBC, iv)
|
|
149
|
+
return unpad(decipher.decrypt(ciphertext), AES.block_size)
|
|
150
|
+
return ciphertext
|
|
151
|
+
|
|
152
|
+
@staticmethod
|
|
153
|
+
def _l01_key(local_key: str, timestamp: int) -> bytes:
|
|
154
|
+
"""Derive key for L01 protocol."""
|
|
155
|
+
hash_input = Utils.encode_timestamp(timestamp) + Utils.ensure_bytes(local_key) + SALT
|
|
156
|
+
return hashlib.sha256(hash_input).digest()
|
|
157
|
+
|
|
158
|
+
@staticmethod
|
|
159
|
+
def _l01_iv(timestamp: int, nonce: int, sequence: int) -> bytes:
|
|
160
|
+
"""Derive IV for L01 protocol."""
|
|
161
|
+
digest_input = sequence.to_bytes(4, "big") + nonce.to_bytes(4, "big") + timestamp.to_bytes(4, "big")
|
|
162
|
+
digest = hashlib.sha256(digest_input).digest()
|
|
163
|
+
return digest[:12]
|
|
164
|
+
|
|
165
|
+
@staticmethod
|
|
166
|
+
def _l01_aad(timestamp: int, nonce: int, sequence: int, connect_nonce: int, ack_nonce: int | None = None) -> bytes:
|
|
167
|
+
"""Derive AAD for L01 protocol."""
|
|
168
|
+
return (
|
|
169
|
+
sequence.to_bytes(4, "big")
|
|
170
|
+
+ connect_nonce.to_bytes(4, "big")
|
|
171
|
+
+ (ack_nonce.to_bytes(4, "big") if ack_nonce is not None else b"")
|
|
172
|
+
+ nonce.to_bytes(4, "big")
|
|
173
|
+
+ timestamp.to_bytes(4, "big")
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
@staticmethod
|
|
177
|
+
def encrypt_gcm_l01(
|
|
178
|
+
plaintext: bytes,
|
|
179
|
+
local_key: str,
|
|
180
|
+
timestamp: int,
|
|
181
|
+
sequence: int,
|
|
182
|
+
nonce: int,
|
|
183
|
+
connect_nonce: int,
|
|
184
|
+
ack_nonce: int | None = None,
|
|
185
|
+
) -> bytes:
|
|
186
|
+
"""Encrypt plaintext for L01 protocol using AES-256-GCM."""
|
|
187
|
+
if not isinstance(plaintext, bytes):
|
|
188
|
+
raise TypeError("plaintext requires bytes")
|
|
189
|
+
|
|
190
|
+
key = Utils._l01_key(local_key, timestamp)
|
|
191
|
+
iv = Utils._l01_iv(timestamp, nonce, sequence)
|
|
192
|
+
aad = Utils._l01_aad(timestamp, nonce, sequence, connect_nonce, ack_nonce)
|
|
193
|
+
|
|
194
|
+
cipher = AES.new(key, AES.MODE_GCM, nonce=iv)
|
|
195
|
+
cipher.update(aad)
|
|
196
|
+
ciphertext, tag = cipher.encrypt_and_digest(plaintext)
|
|
197
|
+
|
|
198
|
+
return ciphertext + tag
|
|
199
|
+
|
|
200
|
+
@staticmethod
|
|
201
|
+
def decrypt_gcm_l01(
|
|
202
|
+
payload: bytes,
|
|
203
|
+
local_key: str,
|
|
204
|
+
timestamp: int,
|
|
205
|
+
sequence: int,
|
|
206
|
+
nonce: int,
|
|
207
|
+
connect_nonce: int,
|
|
208
|
+
ack_nonce: int,
|
|
209
|
+
) -> bytes:
|
|
210
|
+
"""Decrypt payload for L01 protocol using AES-256-GCM."""
|
|
211
|
+
if not isinstance(payload, bytes):
|
|
212
|
+
raise TypeError("payload requires bytes")
|
|
213
|
+
|
|
214
|
+
key = Utils._l01_key(local_key, timestamp)
|
|
215
|
+
iv = Utils._l01_iv(timestamp, nonce, sequence)
|
|
216
|
+
aad = Utils._l01_aad(timestamp, nonce, sequence, connect_nonce, ack_nonce)
|
|
217
|
+
|
|
218
|
+
if len(payload) < 16:
|
|
219
|
+
raise ValueError("Invalid payload length for GCM decryption")
|
|
220
|
+
|
|
221
|
+
tag = payload[-16:]
|
|
222
|
+
ciphertext = payload[:-16]
|
|
223
|
+
|
|
224
|
+
cipher = AES.new(key, AES.MODE_GCM, nonce=iv)
|
|
225
|
+
cipher.update(aad)
|
|
226
|
+
|
|
227
|
+
try:
|
|
228
|
+
return cipher.decrypt_and_verify(ciphertext, tag)
|
|
229
|
+
except ValueError as e:
|
|
230
|
+
raise RoborockException("GCM tag verification failed") from e
|
|
231
|
+
|
|
232
|
+
@staticmethod
|
|
233
|
+
def crc(data: bytes) -> int:
|
|
234
|
+
"""Gather bytes for checksum calculation."""
|
|
235
|
+
return binascii.crc32(data)
|
|
236
|
+
|
|
237
|
+
@staticmethod
|
|
238
|
+
def decompress(compressed_data: bytes):
|
|
239
|
+
"""Decompress data using gzip."""
|
|
240
|
+
return gzip.decompress(compressed_data)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
class EncryptionAdapter(Construct):
    """Adapter to handle communication encryption.

    Wire format: a big-endian 16-bit length followed by that many encrypted
    payload bytes. The cipher is selected by ``context.version``:

    - ``b"A01"``: AES-CBC keyed by ``local_key``, IV from MD5 of ``random`` +
      ``A01_HASH`` (no padding is added or removed here);
    - ``b"B01"``: AES-CBC keyed by ``local_key``, IV from MD5 of ``random`` +
      ``B01_HASH``, with block padding;
    - ``b"L01"``: AES-GCM via ``Utils.encrypt_gcm_l01`` / ``Utils.decrypt_gcm_l01``;
    - anything else: AES-ECB with the token returned by ``token_func(context)``.
    """

    def __init__(self, token_func: Callable):
        super().__init__()
        # Maps the construct context to the ECB token for default versions.
        self.token_func = token_func

    @staticmethod
    def _cbc_cipher(context, iv_suffix: str, iv_start: int):
        """Build the A01/B01 AES-CBC cipher: key is ``local_key``, IV is a 16-char slice of an MD5 hex digest."""
        iv = md5hex(f"{context.random:08x}" + iv_suffix)[iv_start : iv_start + 16]
        key = bytes(context.search("local_key"), "utf-8")
        return AES.new(key, AES.MODE_CBC, bytes(iv, "utf-8"))

    @staticmethod
    def _l01_kwargs(context) -> dict:
        """Collect the L01 GCM parameters shared by encryption and decryption."""
        return {
            "local_key": context.search("local_key"),
            "timestamp": context.timestamp,
            "sequence": context.seq,
            "nonce": context.random,
            "connect_nonce": context.search("connect_nonce"),
            "ack_nonce": context.search("ack_nonce"),
        }

    def _parse(self, stream, context, path):
        length_field = Optional(Int16ub)
        size = length_field.parse_stream(stream, **context)
        if size is None:
            # No length field could be read: there is no payload.
            return None
        if size == 0:
            # Zero-length payload: consume and discard the next two bytes
            # (if present), then report no payload.
            length_field.parse_stream(stream, **context)
            return None
        ciphertext = Bytes(size).parse_stream(stream, **context)
        return self._decode(ciphertext, context, path)

    def _build(self, obj, stream, context, path):
        if obj is None:
            # Nothing to write for an absent payload.
            return obj
        ciphertext = self._encode(obj, context, path)
        size = len(ciphertext)
        Int16ub.build_stream(size, stream, **context)
        Bytes(size).build_stream(ciphertext, stream, **context)
        return obj

    def _encode(self, obj, context, _):
        """Encrypt the payload bytes ``obj`` according to ``context.version``."""
        version = context.version
        if version == b"A01":
            return self._cbc_cipher(context, A01_HASH, 8).encrypt(obj)
        if version == b"B01":
            return self._cbc_cipher(context, B01_HASH, 9).encrypt(pad(obj, AES.block_size))
        if version == b"L01":
            return Utils.encrypt_gcm_l01(plaintext=obj, **self._l01_kwargs(context))
        return Utils.encrypt_ecb(obj, self.token_func(context))

    def _decode(self, obj, context, _):
        """Decrypt the payload bytes ``obj`` according to ``context.version``."""
        version = context.version
        if version == b"A01":
            return self._cbc_cipher(context, A01_HASH, 8).decrypt(obj)
        if version == b"B01":
            return unpad(self._cbc_cipher(context, B01_HASH, 9).decrypt(obj), AES.block_size)
        if version == b"L01":
            return Utils.decrypt_gcm_l01(payload=obj, **self._l01_kwargs(context))
        return Utils.decrypt_ecb(obj, self.token_func(context))
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
class OptionalChecksum(Checksum):
    """Checksum that is only read and verified when the message has a payload.

    If the parsed message's payload is falsy (empty or missing), no checksum
    field is consumed and parsing yields None.
    """

    def _parse(self, stream, context, path):
        if not context.message.value.payload:
            return None

        def _render(value):
            # Show byte digests as hex so the error message stays readable.
            return binascii.hexlify(value) if isinstance(value, bytestringtype) else value

        received = self.checksumfield.parse_stream(stream, **context)
        expected = self.hashfunc(self.bytesfunc(context))
        if received != expected:
            raise ChecksumError(
                f"wrong checksum, read {_render(received)}, computed {_render(expected)}",
                path=path,
            )
        return received
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
class PrefixedStruct(Struct):
    """Struct that resynchronises on a version header and can length-prefix output.

    Parsing: if the stream is not positioned on a known version tag, the rest
    of the stream is scanned for the first one and parsing starts there.

    Building: when ``prefixed`` is truthy in the context, the struct's bytes
    are preceded by a 4-byte field holding their length.
    """

    def _parse(self, stream, context, path):
        known_versions = (b"1.0", b"A01", b"B01", b"L01")
        head = Peek(Optional(Bytes(3))).parse_stream(stream, **context)
        if head in known_versions:
            return super()._parse(stream, context, path)

        # Not positioned on a header: look ahead for the first known version tag.
        origin = stream_tell(stream, path)
        rest = stream.read()
        if not rest:
            # EOF reached, let the parser fail naturally without logging
            stream_seek(stream, origin, 0, path)
            return super()._parse(stream, context, path)

        skip = next(
            (pos for pos in range(len(rest) - 2) if rest[pos : pos + 3] in known_versions),
            None,
        )
        if skip is None:
            _LOGGER.debug("No valid version header found in stream, continuing anyways...")
            # Rewind so the parse below does not start at EOF.
            stream_seek(stream, origin, 0, path)
        else:
            # Skipping 4 bytes is the typical case, so only other amounts are logged.
            if skip != 4:
                _LOGGER.debug("Stripping %d bytes of invalid data from stream", skip)
            stream_seek(stream, origin + skip, 0, path)
        return super()._parse(stream, context, path)

    def _build(self, obj, stream, context, path):
        if not context.search("prefixed"):
            return super()._build(obj, stream, context, path)

        # Reserve 4 bytes for the length prefix, write the struct after it,
        # then go back and fill the prefix in.
        start = stream_tell(stream, path)
        stream_seek(stream, start + 4, 0, path)
        super()._build(obj, stream, context, path)
        end = stream_tell(stream, path)
        prefix = Bytes(4)
        stream_seek(stream, start, 0, path)
        prefix.build_stream(end - start - prefix.sizeof(**context), stream, **context)
        # NOTE(review): ``end`` already points just past the struct, so this
        # leaves the stream 4 bytes beyond the data written; confirm this is
        # intended when several prefixed structs are built back to back.
        stream_seek(stream, end + 4, 0, path)
        return obj
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def _ecb_token(ctx) -> bytes:
    """Derive the ECB token for a message: MD5 of encoded timestamp + local key + SALT."""
    return Utils.md5(Utils.encode_timestamp(ctx.timestamp) + Utils.ensure_bytes(ctx.search("local_key")) + SALT)


def _message_bytes(ctx) -> bytes:
    """Return the raw bytes of the current message, used as the checksum input."""
    return ctx.message.data


# A single message: fixed header fields followed by the length-prefixed,
# encrypted payload. RawCopy exposes both the raw bytes (``data``) and the
# parsed fields (``value``).
_Message = RawCopy(
    Struct(
        "version" / Bytes(3),
        "seq" / Int32ub,
        "random" / Int32ub,
        "timestamp" / Int32ub,
        "protocol" / Int16ub,
        "payload" / EncryptionAdapter(_ecb_token),
    )
)

# A sequence of messages, each followed by an optional CRC-32 checksum of its
# raw bytes; bytes left over after the last parsed message go to "remaining".
_Messages = Struct(
    "messages"
    / GreedyRange(
        PrefixedStruct(
            "message" / _Message,
            "checksum" / OptionalChecksum(Optional(Int32ub), Utils.crc, _message_bytes),
        )
    ),
    "remaining" / Optional(GreedyBytes),
)
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
class _Parser:
    """Convert between raw bytes and RoborockMessage objects using a construct.

    ``con`` describes the wire format; when ``required_local_key`` is True,
    ``parse`` rejects calls that do not supply a local key.
    """

    def __init__(self, con: Construct, required_local_key: bool):
        self.con = con
        self.required_local_key = required_local_key

    @staticmethod
    def _to_roborock_message(entry) -> RoborockMessage:
        """Convert one parsed ``{"message": <RawCopy result>}`` entry into a RoborockMessage."""
        value = entry.message.value
        return RoborockMessage(
            version=value.version,
            seq=value.get("seq"),
            random=value.get("random"),
            timestamp=value.get("timestamp"),
            protocol=value.get("protocol"),
            payload=value.payload,
        )

    @staticmethod
    def _to_build_entry(message: RoborockMessage) -> dict:
        """Convert a RoborockMessage into the nested dict shape the construct builds from."""
        return {
            "message": {
                "value": {
                    "version": message.version,
                    "seq": message.seq,
                    "random": message.random,
                    "timestamp": message.timestamp,
                    "protocol": message.protocol,
                    "payload": message.payload,
                }
            },
        }

    def parse(
        self, data: bytes, local_key: str | None = None, connect_nonce: int | None = None, ack_nonce: int | None = None
    ) -> tuple[list[RoborockMessage], bytes]:
        """Parse ``data`` into Roborock messages.

        ``local_key``, ``connect_nonce`` and ``ack_nonce`` are passed to the
        construct as context values.

        :return: the parsed messages and the leftover bytes (``b""`` if none)
        :raises RoborockException: if a local key is required but not given
        """
        if self.required_local_key and local_key is None:
            raise RoborockException("Local key is required")
        parsed = self.con.parse(data, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce)
        # A result with a top-level "message" is wrapped so it is handled like
        # an entry of the multi-message "messages" list.
        entries = [Container({"message": parsed.message})] if parsed.get("message") else parsed.messages
        messages = [self._to_roborock_message(entry) for entry in entries]
        remaining = parsed.get("remaining") or b""
        return messages, remaining

    def build(
        self,
        roborock_messages: list[RoborockMessage] | RoborockMessage,
        local_key: str,
        prefixed: bool = True,
        connect_nonce: int | None = None,
        ack_nonce: int | None = None,
    ) -> bytes:
        """Serialize one message or a list of messages into bytes.

        ``prefixed`` controls whether each message gets a 4-byte length
        prefix; the key and nonces are passed to the construct as context.
        """
        if isinstance(roborock_messages, RoborockMessage):
            roborock_messages = [roborock_messages]
        messages = [self._to_build_entry(message) for message in roborock_messages]
        return self.con.build(
            {"messages": messages, "remaining": b""},
            local_key=local_key,
            prefixed=prefixed,
            connect_nonce=connect_nonce,
            ack_nonce=ack_nonce,
        )
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
# Shared parser for the multi-message wire format; parsing requires a local key.
MessageParser: _Parser = _Parser(con=_Messages, required_local_key=True)
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
def create_mqtt_params(rriot: RRiot) -> MqttParams:
    """Return the MQTT parameters for this user.

    The broker endpoint is parsed from ``rriot.r.m``; TLS is enabled when its
    scheme is ``ssl``. Username and password are slices of MD5 hex digests of
    ``rriot.u`` and ``rriot.s`` each joined to ``rriot.k`` with a colon.

    :raises RoborockException: if the URL yields no hostname or no port.
    """
    parsed = urlparse(rriot.r.m)
    hostname = parsed.hostname
    if not isinstance(hostname, str):
        raise RoborockException(f"Url parsing '{rriot.r.m}' returned an invalid hostname")
    # Read the port only after the hostname check, as ``.port`` may itself raise.
    port = parsed.port
    if not port:
        raise RoborockException(f"Url parsing '{rriot.r.m}' returned an invalid port")
    username = md5hex(rriot.u + ":" + rriot.k)[2:10]
    password = md5hex(rriot.s + ":" + rriot.k)[16:]
    return MqttParams(
        host=str(hostname),
        port=port,
        tls=parsed.scheme == "ssl",
        username=username,
        password=password,
    )
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
# Turns a chunk of raw bytes into a list of RoborockMessages.
Decoder = Callable[[bytes], list[RoborockMessage]]
# Turns a single RoborockMessage into raw bytes.
Encoder = Callable[[RoborockMessage], bytes]
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
def create_mqtt_decoder(local_key: str) -> Decoder:
|
|
509
|
+
"""Create a decoder for MQTT messages."""
|
|
510
|
+
|
|
511
|
+
def decode(data: bytes) -> list[RoborockMessage]:
|
|
512
|
+
"""Parse the given data into Roborock messages."""
|
|
513
|
+
messages, _ = MessageParser.parse(data, local_key)
|
|
514
|
+
return messages
|
|
515
|
+
|
|
516
|
+
return decode
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
def create_mqtt_encoder(local_key: str) -> Encoder:
|
|
520
|
+
"""Create an encoder for MQTT messages."""
|
|
521
|
+
|
|
522
|
+
def encode(messages: RoborockMessage) -> bytes:
|
|
523
|
+
"""Build the given Roborock messages into a byte string."""
|
|
524
|
+
return MessageParser.build(messages, local_key, prefixed=False)
|
|
525
|
+
|
|
526
|
+
return encode
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def create_local_decoder(local_key: str, connect_nonce: int | None = None, ack_nonce: int | None = None) -> Decoder:
|
|
530
|
+
"""Create a decoder for local API messages."""
|
|
531
|
+
|
|
532
|
+
# This buffer is used to accumulate bytes until a complete message can be parsed.
|
|
533
|
+
# It is defined outside the decode function to maintain state across calls.
|
|
534
|
+
buffer: bytes = b""
|
|
535
|
+
|
|
536
|
+
def decode(bytes_data: bytes) -> list[RoborockMessage]:
|
|
537
|
+
"""Parse the given data into Roborock messages."""
|
|
538
|
+
nonlocal buffer
|
|
539
|
+
buffer += bytes_data
|
|
540
|
+
parsed_messages, remaining = MessageParser.parse(
|
|
541
|
+
buffer, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce
|
|
542
|
+
)
|
|
543
|
+
if remaining:
|
|
544
|
+
_LOGGER.debug("Found %d extra bytes: %s", len(remaining), remaining)
|
|
545
|
+
buffer = remaining
|
|
546
|
+
return parsed_messages
|
|
547
|
+
|
|
548
|
+
return decode
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def create_local_encoder(local_key: str, connect_nonce: int | None = None, ack_nonce: int | None = None) -> Encoder:
|
|
552
|
+
"""Create an encoder for local API messages."""
|
|
553
|
+
|
|
554
|
+
def encode(message: RoborockMessage) -> bytes:
|
|
555
|
+
"""Called when data is sent to the transport."""
|
|
556
|
+
return MessageParser.build(message, local_key=local_key, connect_nonce=connect_nonce, ack_nonce=ack_nonce)
|
|
557
|
+
|
|
558
|
+
return encode
|