python-roborock 2.35.0__tar.gz → 2.37.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. {python_roborock-2.35.0 → python_roborock-2.37.0}/PKG-INFO +1 -1
  2. {python_roborock-2.35.0 → python_roborock-2.37.0}/pyproject.toml +1 -1
  3. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/api.py +1 -6
  4. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/containers.py +23 -0
  5. python_roborock-2.37.0/roborock/devices/a01_channel.py +93 -0
  6. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/b01_channel.py +2 -5
  7. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/local_channel.py +11 -40
  8. python_roborock-2.37.0/roborock/devices/mqtt_channel.py +95 -0
  9. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/traits/b01/props.py +2 -3
  10. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/v1_channel.py +1 -1
  11. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/v1_rpc_channel.py +38 -10
  12. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/protocol.py +20 -0
  13. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/protocols/v1_protocol.py +69 -33
  14. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/roborock_message.py +0 -10
  15. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/version_1_apis/roborock_client_v1.py +19 -17
  16. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/version_1_apis/roborock_local_client_v1.py +24 -17
  17. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/version_1_apis/roborock_mqtt_client_v1.py +11 -14
  18. python_roborock-2.35.0/roborock/devices/a01_channel.py +0 -43
  19. python_roborock-2.35.0/roborock/devices/mqtt_channel.py +0 -137
  20. {python_roborock-2.35.0 → python_roborock-2.37.0}/LICENSE +0 -0
  21. {python_roborock-2.35.0 → python_roborock-2.37.0}/README.md +0 -0
  22. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/__init__.py +0 -0
  23. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/clean_modes.py +0 -0
  24. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/cli.py +0 -0
  25. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/cloud_api.py +0 -0
  26. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/code_mappings.py +0 -0
  27. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/command_cache.py +0 -0
  28. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/const.py +0 -0
  29. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/device_features.py +0 -0
  30. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/README.md +0 -0
  31. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/__init__.py +0 -0
  32. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/cache.py +0 -0
  33. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/channel.py +0 -0
  34. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/device.py +0 -0
  35. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/device_manager.py +0 -0
  36. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/traits/b01/__init__.py +0 -0
  37. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/traits/dyad.py +0 -0
  38. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/traits/status.py +0 -0
  39. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/traits/trait.py +0 -0
  40. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/devices/traits/zeo.py +0 -0
  41. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/exceptions.py +0 -0
  42. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/mqtt/__init__.py +0 -0
  43. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/mqtt/roborock_session.py +0 -0
  44. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/mqtt/session.py +0 -0
  45. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/protocols/a01_protocol.py +0 -0
  46. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/protocols/b01_protocol.py +0 -0
  47. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/py.typed +0 -0
  48. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/roborock_future.py +0 -0
  49. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/roborock_typing.py +0 -0
  50. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/util.py +0 -0
  51. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/version_1_apis/__init__.py +0 -0
  52. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/version_a01_apis/__init__.py +0 -0
  53. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/version_a01_apis/roborock_client_a01.py +0 -0
  54. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/version_a01_apis/roborock_mqtt_client_a01.py +0 -0
  55. {python_roborock-2.35.0 → python_roborock-2.37.0}/roborock/web_api.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: python-roborock
3
- Version: 2.35.0
3
+ Version: 2.37.0
4
4
  Summary: A package to control Roborock vacuums.
5
5
  Home-page: https://github.com/humbertogontijo/python-roborock
6
6
  License: GPL-3.0-only
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "python-roborock"
3
- version = "2.35.0"
3
+ version = "2.37.0"
4
4
  description = "A package to control Roborock vacuums."
5
5
  authors = ["humbertogontijo <humbertogontijo@users.noreply.github.com>"]
6
6
  license = "GPL-3.0-only"
@@ -3,9 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import asyncio
6
- import base64
7
6
  import logging
8
- import secrets
9
7
  import time
10
8
  from abc import ABC, abstractmethod
11
9
  from typing import Any
@@ -37,14 +35,11 @@ class RoborockClient(ABC):
37
35
  def __init__(self, device_info: DeviceData) -> None:
38
36
  """Initialize RoborockClient."""
39
37
  self.device_info = device_info
40
- self._nonce = secrets.token_bytes(16)
41
38
  self._waiting_queue: dict[int, RoborockFuture] = {}
42
39
  self._last_device_msg_in = time.monotonic()
43
40
  self._last_disconnection = time.monotonic()
44
41
  self.keep_alive = KEEPALIVE
45
- self._diagnostic_data: dict[str, dict[str, Any]] = {
46
- "misc_info": {"Nonce": base64.b64encode(self._nonce).decode("utf-8")}
47
- }
42
+ self._diagnostic_data: dict[str, dict[str, Any]] = {}
48
43
  self.is_available: bool = True
49
44
 
50
45
  async def async_release(self) -> None:
@@ -725,6 +725,29 @@ class NetworkInfo(RoborockBase):
725
725
  rssi: int | None = None
726
726
 
727
727
 
728
+ @dataclass
729
+ class AppInitStatusLocalInfo(RoborockBase):
730
+ location: str
731
+ bom: str | None = None
732
+ featureset: int | None = None
733
+ language: str | None = None
734
+ logserver: str | None = None
735
+ wifiplan: str | None = None
736
+ timezone: str | None = None
737
+ name: str | None = None
738
+
739
+
740
+ @dataclass
741
+ class AppInitStatus(RoborockBase):
742
+ local_info: AppInitStatusLocalInfo
743
+ feature_info: list[int]
744
+ new_feature_info: int
745
+ new_feature_info_str: str
746
+ new_feature_info_2: int | None = None
747
+ carriage_type: int | None = None
748
+ dsp_version: int | None = None
749
+
750
+
728
751
  @dataclass
729
752
  class DeviceData(RoborockBase):
730
753
  device: HomeDataDevice
@@ -0,0 +1,93 @@
1
+ """Thin wrapper around the MQTT channel for Roborock A01 devices."""
2
+
3
+ import asyncio
4
+ import logging
5
+ from typing import Any, overload
6
+
7
+ from roborock.exceptions import RoborockException
8
+ from roborock.protocols.a01_protocol import (
9
+ decode_rpc_response,
10
+ encode_mqtt_payload,
11
+ )
12
+ from roborock.roborock_message import (
13
+ RoborockDyadDataProtocol,
14
+ RoborockMessage,
15
+ RoborockZeoProtocol,
16
+ )
17
+
18
+ from .mqtt_channel import MqttChannel
19
+
20
+ _LOGGER = logging.getLogger(__name__)
21
+ _TIMEOUT = 10.0
22
+
23
+ # Both RoborockDyadDataProtocol and RoborockZeoProtocol have the same
24
+ # value for ID_QUERY
25
+ _ID_QUERY = int(RoborockDyadDataProtocol.ID_QUERY)
26
+
27
+
28
+ @overload
29
+ async def send_decoded_command(
30
+ mqtt_channel: MqttChannel,
31
+ params: dict[RoborockDyadDataProtocol, Any],
32
+ ) -> dict[RoborockDyadDataProtocol, Any]:
33
+ ...
34
+
35
+
36
+ @overload
37
+ async def send_decoded_command(
38
+ mqtt_channel: MqttChannel,
39
+ params: dict[RoborockZeoProtocol, Any],
40
+ ) -> dict[RoborockZeoProtocol, Any]:
41
+ ...
42
+
43
+
44
+ async def send_decoded_command(
45
+ mqtt_channel: MqttChannel,
46
+ params: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any],
47
+ ) -> dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any]:
48
+ """Send a command on the MQTT channel and get a decoded response."""
49
+ _LOGGER.debug("Sending MQTT command: %s", params)
50
+ roborock_message = encode_mqtt_payload(params)
51
+
52
+ # For commands that set values: send the command and do not
53
+ # block waiting for a response. Queries are handled below.
54
+ param_values = {int(k): v for k, v in params.items()}
55
+ if not (query_values := param_values.get(_ID_QUERY)):
56
+ await mqtt_channel.publish(roborock_message)
57
+ return {}
58
+
59
+ # Merge any results together than contain the requested data. This
60
+ # does not use a future since it needs to merge results across responses.
61
+ # This could be simplified if we can assume there is a single response.
62
+ finished = asyncio.Event()
63
+ result: dict[int, Any] = {}
64
+
65
+ def find_response(response_message: RoborockMessage) -> None:
66
+ """Handle incoming messages and resolve the future."""
67
+ try:
68
+ decoded = decode_rpc_response(response_message)
69
+ except RoborockException as ex:
70
+ _LOGGER.info("Failed to decode a01 message: %s: %s", response_message, ex)
71
+ return
72
+ for key, value in decoded.items():
73
+ if key in query_values:
74
+ result[key] = value
75
+ if len(result) != len(query_values):
76
+ _LOGGER.debug("Incomplete query response: %s != %s", result, query_values)
77
+ return
78
+ _LOGGER.debug("Received query response: %s", result)
79
+ if not finished.is_set():
80
+ finished.set()
81
+
82
+ unsub = await mqtt_channel.subscribe(find_response)
83
+
84
+ try:
85
+ await mqtt_channel.publish(roborock_message)
86
+ try:
87
+ await asyncio.wait_for(finished.wait(), timeout=_TIMEOUT)
88
+ except TimeoutError as ex:
89
+ raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
90
+ finally:
91
+ unsub()
92
+
93
+ return result # type: ignore[return-value]
@@ -3,12 +3,10 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import logging
6
- from typing import Any
7
6
 
8
7
  from roborock.protocols.b01_protocol import (
9
8
  CommandType,
10
9
  ParamsType,
11
- decode_rpc_response,
12
10
  encode_mqtt_payload,
13
11
  )
14
12
 
@@ -22,9 +20,8 @@ async def send_decoded_command(
22
20
  dps: int,
23
21
  command: CommandType,
24
22
  params: ParamsType,
25
- ) -> dict[int, Any]:
23
+ ) -> None:
26
24
  """Send a command on the MQTT channel and get a decoded response."""
27
25
  _LOGGER.debug("Sending MQTT command: %s", params)
28
26
  roborock_message = encode_mqtt_payload(dps, command, params)
29
- response = await mqtt_channel.send_message(roborock_message)
30
- return decode_rpc_response(response) # type: ignore[return-value]
27
+ await mqtt_channel.publish(roborock_message)
@@ -4,7 +4,6 @@ import asyncio
4
4
  import logging
5
5
  from collections.abc import Callable
6
6
  from dataclasses import dataclass
7
- from json import JSONDecodeError
8
7
 
9
8
  from roborock.exceptions import RoborockConnectionException, RoborockException
10
9
  from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
@@ -46,11 +45,8 @@ class LocalChannel(Channel):
46
45
  self._subscribers: list[Callable[[RoborockMessage], None]] = []
47
46
  self._is_connected = False
48
47
 
49
- # RPC support
50
- self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {}
51
48
  self._decoder: Decoder = create_local_decoder(local_key)
52
49
  self._encoder: Encoder = create_local_encoder(local_key)
53
- self._queue_lock = asyncio.Lock()
54
50
 
55
51
  @property
56
52
  def is_connected(self) -> bool:
@@ -87,7 +83,6 @@ class LocalChannel(Channel):
87
83
  return
88
84
  for message in messages:
89
85
  _LOGGER.debug("Received message: %s", message)
90
- asyncio.create_task(self._resolve_future_with_lock(message))
91
86
  for callback in self._subscribers:
92
87
  try:
93
88
  callback(message)
@@ -109,48 +104,24 @@ class LocalChannel(Channel):
109
104
 
110
105
  return unsubscribe
111
106
 
112
- async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
113
- """Resolve waiting future with proper locking."""
114
- if (request_id := message.get_request_id()) is None:
115
- _LOGGER.debug("Received message with no request_id")
116
- return
117
- async with self._queue_lock:
118
- if (future := self._waiting_queue.pop(request_id, None)) is not None:
119
- future.set_result(message)
120
- else:
121
- _LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
122
-
123
- async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
124
- """Send a command message and wait for the response message."""
107
+ async def publish(self, message: RoborockMessage) -> None:
108
+ """Send a command message.
109
+
110
+ The caller is responsible for associating the message with its response.
111
+ """
125
112
  if not self._transport or not self._is_connected:
126
113
  raise RoborockConnectionException("Not connected to device")
127
114
 
128
- try:
129
- if (request_id := message.get_request_id()) is None:
130
- raise RoborockException("Message must have a request_id for RPC calls")
131
- except (ValueError, JSONDecodeError) as err:
132
- _LOGGER.exception("Error getting request_id from message: %s", err)
133
- raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
134
-
135
- future: asyncio.Future[RoborockMessage] = asyncio.Future()
136
- async with self._queue_lock:
137
- if request_id in self._waiting_queue:
138
- raise RoborockException(f"Request ID {request_id} already pending, cannot send command")
139
- self._waiting_queue[request_id] = future
140
-
141
115
  try:
142
116
  encoded_msg = self._encoder(message)
117
+ except Exception as err:
118
+ _LOGGER.exception("Error encoding MQTT message: %s", err)
119
+ raise RoborockException(f"Failed to encode MQTT message: {err}") from err
120
+ try:
143
121
  self._transport.write(encoded_msg)
144
- return await asyncio.wait_for(future, timeout=timeout)
145
- except asyncio.TimeoutError as ex:
146
- async with self._queue_lock:
147
- self._waiting_queue.pop(request_id, None)
148
- raise RoborockException(f"Command timed out after {timeout}s") from ex
149
- except Exception:
122
+ except Exception as err:
150
123
  logging.exception("Uncaught error sending command")
151
- async with self._queue_lock:
152
- self._waiting_queue.pop(request_id, None)
153
- raise
124
+ raise RoborockException(f"Failed to send message: {message}") from err
154
125
 
155
126
 
156
127
  # This module provides a factory function to create LocalChannel instances.
@@ -0,0 +1,95 @@
1
+ """Modules for communicating with specific Roborock devices over MQTT."""
2
+
3
+ import logging
4
+ from collections.abc import Callable
5
+
6
+ from roborock.containers import HomeDataDevice, RRiot, UserData
7
+ from roborock.exceptions import RoborockException
8
+ from roborock.mqtt.session import MqttParams, MqttSession, MqttSessionException
9
+ from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder
10
+ from roborock.roborock_message import RoborockMessage
11
+
12
+ from .channel import Channel
13
+
14
+ _LOGGER = logging.getLogger(__name__)
15
+
16
+
17
+ class MqttChannel(Channel):
18
+ """Simple RPC-style channel for communicating with a device over MQTT.
19
+
20
+ Handles request/response correlation and timeouts, but leaves message
21
+ format most parsing to higher-level components.
22
+ """
23
+
24
+ def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot: RRiot, mqtt_params: MqttParams):
25
+ self._mqtt_session = mqtt_session
26
+ self._duid = duid
27
+ self._local_key = local_key
28
+ self._rriot = rriot
29
+ self._mqtt_params = mqtt_params
30
+
31
+ self._decoder = create_mqtt_decoder(local_key)
32
+ self._encoder = create_mqtt_encoder(local_key)
33
+
34
+ @property
35
+ def is_connected(self) -> bool:
36
+ """Return true if the channel is connected.
37
+
38
+ This passes through the underlying MQTT session's connected state.
39
+ """
40
+ return self._mqtt_session.connected
41
+
42
+ @property
43
+ def _publish_topic(self) -> str:
44
+ """Topic to send commands to the device."""
45
+ return f"rr/m/i/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}"
46
+
47
+ @property
48
+ def _subscribe_topic(self) -> str:
49
+ """Topic to receive responses from the device."""
50
+ return f"rr/m/o/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}"
51
+
52
+ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
53
+ """Subscribe to the device's response topic.
54
+
55
+ The callback will be called with the message payload when a message is received.
56
+
57
+ Returns a callable that can be used to unsubscribe from the topic.
58
+ """
59
+
60
+ def message_handler(payload: bytes) -> None:
61
+ if not (messages := self._decoder(payload)):
62
+ _LOGGER.warning("Failed to decode MQTT message: %s", payload)
63
+ return
64
+ for message in messages:
65
+ _LOGGER.debug("Received message: %s", message)
66
+ try:
67
+ callback(message)
68
+ except Exception as e:
69
+ _LOGGER.exception("Uncaught error in message handler callback: %s", e)
70
+
71
+ return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
72
+
73
+ async def publish(self, message: RoborockMessage) -> None:
74
+ """Publish a command message.
75
+
76
+ The caller is responsible for handling any responses and associating them
77
+ with the incoming request.
78
+ """
79
+ try:
80
+ encoded_msg = self._encoder(message)
81
+ except Exception as e:
82
+ _LOGGER.exception("Error encoding MQTT message: %s", e)
83
+ raise RoborockException(f"Failed to encode MQTT message: {e}") from e
84
+ try:
85
+ return await self._mqtt_session.publish(self._publish_topic, encoded_msg)
86
+ except MqttSessionException as e:
87
+ _LOGGER.exception("Error publishing MQTT message: %s", e)
88
+ raise RoborockException(f"Failed to publish MQTT message: {e}") from e
89
+
90
+
91
+ def create_mqtt_channel(
92
+ user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice
93
+ ) -> MqttChannel:
94
+ """Create a V1Channel for the given device."""
95
+ return MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)
@@ -1,7 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- from typing import Any
5
4
 
6
5
  from roborock import RoborockB01Methods
7
6
  from roborock.roborock_message import RoborockB01Props
@@ -26,6 +25,6 @@ class B01PropsApi(Trait):
26
25
  """Initialize the B01Props API."""
27
26
  self._channel = channel
28
27
 
29
- async def query_values(self, props: list[RoborockB01Props]) -> dict[int, Any]:
28
+ async def query_values(self, props: list[RoborockB01Props]) -> None:
30
29
  """Query the device for the values of the given Dyad protocols."""
31
- return await send_decoded_command(self._channel, dps=10000, command=RoborockB01Methods.GET_PROP, params=props)
30
+ await send_decoded_command(self._channel, dps=10000, command=RoborockB01Methods.GET_PROP, params=props)
@@ -79,7 +79,7 @@ class V1Channel(Channel):
79
79
  @property
80
80
  def is_mqtt_connected(self) -> bool:
81
81
  """Return whether MQTT connection is available."""
82
- return self._mqtt_unsub is not None
82
+ return self._mqtt_unsub is not None and self._mqtt_channel.is_connected
83
83
 
84
84
  @property
85
85
  def rpc_channel(self) -> V1RpcChannel:
@@ -6,25 +6,27 @@ a simple interface for sending commands and receiving responses over both MQTT
6
6
  and local connections, preferring local when available.
7
7
  """
8
8
 
9
+ import asyncio
9
10
  import logging
10
11
  from collections.abc import Callable
11
12
  from typing import Any, Protocol, TypeVar, overload
12
13
 
13
14
  from roborock.containers import RoborockBase
15
+ from roborock.exceptions import RoborockException
14
16
  from roborock.protocols.v1_protocol import (
15
17
  CommandType,
16
18
  ParamsType,
19
+ RequestMessage,
17
20
  SecurityData,
18
- create_mqtt_payload_encoder,
19
21
  decode_rpc_response,
20
- encode_local_payload,
21
22
  )
22
- from roborock.roborock_message import RoborockMessage
23
+ from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
23
24
 
24
25
  from .local_channel import LocalChannel
25
26
  from .mqtt_channel import MqttChannel
26
27
 
27
28
  _LOGGER = logging.getLogger(__name__)
29
+ _TIMEOUT = 10.0
28
30
 
29
31
 
30
32
  _T = TypeVar("_T", bound=RoborockBase)
@@ -116,7 +118,7 @@ class PayloadEncodedV1RpcChannel(BaseV1RpcChannel):
116
118
  self,
117
119
  name: str,
118
120
  channel: MqttChannel | LocalChannel,
119
- payload_encoder: Callable[[CommandType, ParamsType], RoborockMessage],
121
+ payload_encoder: Callable[[RequestMessage], RoborockMessage],
120
122
  ) -> None:
121
123
  """Initialize the channel with a raw channel and an encoder function."""
122
124
  self._name = name
@@ -131,18 +133,44 @@ class PayloadEncodedV1RpcChannel(BaseV1RpcChannel):
131
133
  ) -> Any:
132
134
  """Send a command and return a parsed response RoborockBase type."""
133
135
  _LOGGER.debug("Sending command (%s): %s, params=%s", self._name, method, params)
134
- message = self._payload_encoder(method, params)
135
- response = await self._channel.send_message(message)
136
- return decode_rpc_response(response)
136
+ request_message = RequestMessage(method, params=params)
137
+ message = self._payload_encoder(request_message)
138
+
139
+ future: asyncio.Future[dict[str, Any]] = asyncio.Future()
140
+
141
+ def find_response(response_message: RoborockMessage) -> None:
142
+ try:
143
+ decoded = decode_rpc_response(response_message)
144
+ except RoborockException:
145
+ return
146
+ if decoded.request_id == request_message.request_id:
147
+ future.set_result(decoded.data)
148
+
149
+ unsub = await self._channel.subscribe(find_response)
150
+ try:
151
+ await self._channel.publish(message)
152
+ return await asyncio.wait_for(future, timeout=_TIMEOUT)
153
+ except TimeoutError as ex:
154
+ future.cancel()
155
+ raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
156
+ finally:
157
+ unsub()
137
158
 
138
159
 
139
160
  def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel:
140
161
  """Create a V1 RPC channel using an MQTT channel."""
141
- payload_encoder = create_mqtt_payload_encoder(security_data)
142
- return PayloadEncodedV1RpcChannel("mqtt", mqtt_channel, payload_encoder)
162
+ return PayloadEncodedV1RpcChannel(
163
+ "mqtt",
164
+ mqtt_channel,
165
+ lambda x: x.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data),
166
+ )
143
167
 
144
168
 
145
169
  def create_combined_rpc_channel(local_channel: LocalChannel, mqtt_rpc_channel: V1RpcChannel) -> V1RpcChannel:
146
170
  """Create a V1 RPC channel that combines local and MQTT channels."""
147
- local_rpc_channel = PayloadEncodedV1RpcChannel("local", local_channel, encode_local_payload)
171
+ local_rpc_channel = PayloadEncodedV1RpcChannel(
172
+ "local",
173
+ local_channel,
174
+ lambda x: x.encode_message(RoborockMessageProtocol.GENERAL_REQUEST),
175
+ )
148
176
  return CombinedV1RpcChannel(local_channel, local_rpc_channel, mqtt_rpc_channel)
@@ -147,6 +147,26 @@ class Utils:
147
147
  return unpad(decipher.decrypt(ciphertext), AES.block_size)
148
148
  return ciphertext
149
149
 
150
+ @staticmethod
151
+ def encrypt_cbc(plaintext: bytes, token: bytes) -> bytes:
152
+ """Encrypt plaintext with a given token using cbc mode.
153
+
154
+ This is currently used for testing purposes only.
155
+
156
+ :param bytes plaintext: Plaintext (json) to encrypt
157
+ :param bytes token: Token to use
158
+ :return: Encrypted bytes
159
+ """
160
+ if not isinstance(plaintext, bytes):
161
+ raise TypeError("plaintext requires bytes")
162
+ Utils.verify_token(token)
163
+ iv = bytes(AES.block_size)
164
+ cipher = AES.new(token, AES.MODE_CBC, iv)
165
+ if plaintext:
166
+ plaintext = pad(plaintext, AES.block_size)
167
+ return cipher.encrypt(plaintext)
168
+ return plaintext
169
+
150
170
  @staticmethod
151
171
  def decrypt_cbc(ciphertext: bytes, token: bytes) -> bytes:
152
172
  """Decrypt ciphertext with a given token using cbc mode.
@@ -7,6 +7,7 @@ import json
7
7
  import logging
8
8
  import math
9
9
  import secrets
10
+ import struct
10
11
  import time
11
12
  from collections.abc import Callable
12
13
  from dataclasses import dataclass, field
@@ -24,8 +25,6 @@ _LOGGER = logging.getLogger(__name__)
24
25
  __all__ = [
25
26
  "SecurityData",
26
27
  "create_security_data",
27
- "create_mqtt_payload_encoder",
28
- "encode_local_payload",
29
28
  "decode_rpc_response",
30
29
  ]
31
30
 
@@ -44,6 +43,10 @@ class SecurityData:
44
43
  """Convert security data to a dictionary for sending in the payload."""
45
44
  return {"security": {"endpoint": self.endpoint, "nonce": self.nonce.hex().lower()}}
46
45
 
46
+ def to_diagnostic_data(self) -> dict[str, Any]:
47
+ """Convert security data to a dictionary for debugging purposes."""
48
+ return {"nonce": self.nonce.hex().lower()}
49
+
47
50
 
48
51
  def create_security_data(rriot: RRiot) -> SecurityData:
49
52
  """Create a SecurityData instance for the given endpoint and nonce."""
@@ -61,7 +64,19 @@ class RequestMessage:
61
64
  timestamp: int = field(default_factory=lambda: math.floor(time.time()))
62
65
  request_id: int = field(default_factory=lambda: get_next_int(10000, 32767))
63
66
 
64
- def as_payload(self, security_data: SecurityData | None) -> bytes:
67
+ def encode_message(
68
+ self,
69
+ protocol: RoborockMessageProtocol,
70
+ security_data: SecurityData | None = None,
71
+ ) -> RoborockMessage:
72
+ """Convert the request message to a RoborockMessage."""
73
+ return RoborockMessage(
74
+ timestamp=self.timestamp,
75
+ protocol=protocol,
76
+ payload=self._as_payload(security_data=security_data),
77
+ )
78
+
79
+ def _as_payload(self, security_data: SecurityData | None) -> bytes:
65
80
  """Convert the request arguments to a dictionary."""
66
81
  inner = {
67
82
  "id": self.request_id,
@@ -80,36 +95,18 @@ class RequestMessage:
80
95
  )
81
96
 
82
97
 
83
- def create_mqtt_payload_encoder(security_data: SecurityData) -> Callable[[CommandType, ParamsType], RoborockMessage]:
84
- """Create a payload encoder for V1 commands over MQTT."""
85
-
86
- def _get_payload(method: CommandType, params: ParamsType) -> RoborockMessage:
87
- """Build the payload for a V1 command."""
88
- request = RequestMessage(method=method, params=params)
89
- payload = request.as_payload(security_data) # always secure
90
- return RoborockMessage(
91
- timestamp=request.timestamp,
92
- protocol=RoborockMessageProtocol.RPC_REQUEST,
93
- payload=payload,
94
- )
95
-
96
- return _get_payload
97
-
98
+ @dataclass(kw_only=True, frozen=True)
99
+ class ResponseMessage:
100
+ """Data structure for v1 RoborockMessage responses."""
98
101
 
99
- def encode_local_payload(method: CommandType, params: ParamsType) -> RoborockMessage:
100
- """Encode payload for V1 commands over local connection."""
102
+ request_id: int | None
103
+ """The request ID of the response."""
101
104
 
102
- request = RequestMessage(method=method, params=params)
103
- payload = request.as_payload(security_data=None)
105
+ data: dict[str, Any]
106
+ """The data of the response."""
104
107
 
105
- return RoborockMessage(
106
- timestamp=request.timestamp,
107
- protocol=RoborockMessageProtocol.GENERAL_REQUEST,
108
- payload=payload,
109
- )
110
108
 
111
-
112
- def decode_rpc_response(message: RoborockMessage) -> dict[str, Any]:
109
+ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage:
113
110
  """Decode a V1 RPC_RESPONSE message."""
114
111
  if not message.payload:
115
112
  raise RoborockException("Invalid V1 message format: missing payload")
@@ -123,14 +120,19 @@ def decode_rpc_response(message: RoborockMessage) -> dict[str, Any]:
123
120
  if not isinstance(datapoints, dict):
124
121
  raise RoborockException(f"Invalid V1 message format: 'dps' should be a dictionary for {message.payload!r}")
125
122
 
126
- if not (data_point := datapoints.get("102")):
127
- raise RoborockException("Invalid V1 message format: missing '102' data point")
123
+ if not (data_point := datapoints.get(str(RoborockMessageProtocol.RPC_RESPONSE))):
124
+ raise RoborockException(
125
+ f"Invalid V1 message format: missing '{RoborockMessageProtocol.RPC_RESPONSE}' data point"
126
+ )
128
127
 
129
128
  try:
130
129
  data_point_response = json.loads(data_point)
131
130
  except (json.JSONDecodeError, TypeError) as e:
132
- raise RoborockException(f"Invalid V1 message data point '102': {e} for {message.payload!r}") from e
131
+ raise RoborockException(
132
+ f"Invalid V1 message data point '{RoborockMessageProtocol.RPC_RESPONSE}': {e} for {message.payload!r}"
133
+ ) from e
133
134
 
135
+ request_id: int | None = data_point_response.get("id")
134
136
  if error := data_point_response.get("error"):
135
137
  raise RoborockException(f"Error in message: {error}")
136
138
 
@@ -141,4 +143,38 @@ def decode_rpc_response(message: RoborockMessage) -> dict[str, Any]:
141
143
  result = result[0]
142
144
  if not isinstance(result, dict):
143
145
  raise RoborockException(f"Invalid V1 message format: 'result' should be a dictionary for {message.payload!r}")
144
- return result
146
+ return ResponseMessage(request_id=request_id, data=result)
147
+
148
+
149
+ @dataclass
150
+ class MapResponse:
151
+ """Data structure for the V1 Map response."""
152
+
153
+ request_id: int
154
+ """The request ID of the map response."""
155
+
156
+ data: bytes
157
+ """The map data, decrypted and decompressed."""
158
+
159
+
160
+ def create_map_response_decoder(security_data: SecurityData) -> Callable[[RoborockMessage], MapResponse]:
161
+ """Create a decoder for V1 map response messages."""
162
+
163
+ def _decode_map_response(message: RoborockMessage) -> MapResponse:
164
+ """Decode a V1 map response message."""
165
+ if not message.payload or len(message.payload) < 24:
166
+ raise RoborockException("Invalid V1 map response format: missing payload")
167
+ header, body = message.payload[:24], message.payload[24:]
168
+ [endpoint, _, request_id, _] = struct.unpack("<8s8sH6s", header)
169
+ if not endpoint.decode().startswith(security_data.endpoint):
170
+ raise RoborockException(
171
+ f"Invalid V1 map response endpoint: {endpoint!r}, expected {security_data.endpoint!r}"
172
+ )
173
+ try:
174
+ decrypted = Utils.decrypt_cbc(body, security_data.nonce)
175
+ except ValueError as err:
176
+ raise RoborockException("Failed to decode map message payload") from err
177
+ decompressed = Utils.decompress(decrypted)
178
+ return MapResponse(request_id=request_id, data=decompressed)
179
+
180
+ return _decode_map_response
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import json
4
3
  import math
5
4
  import time
6
5
  from dataclasses import dataclass, field
@@ -247,12 +246,3 @@ class RoborockMessage:
247
246
  version: bytes = b"1.0"
248
247
  random: int = field(default_factory=lambda: get_next_int(10000, 99999))
249
248
  timestamp: int = field(default_factory=lambda: math.floor(time.time()))
250
-
251
- def get_request_id(self) -> int | None:
252
- if self.payload:
253
- payload = json.loads(self.payload.decode())
254
- for data_point_number, data_point in payload.get("dps").items():
255
- if data_point_number in ["101", "102"]:
256
- data_point_response = json.loads(data_point)
257
- return data_point_response.get("id")
258
- return None
@@ -1,13 +1,13 @@
1
1
  import asyncio
2
2
  import dataclasses
3
3
  import json
4
- import struct
5
4
  import time
6
5
  from abc import ABC, abstractmethod
7
6
  from collections.abc import Callable, Coroutine
8
7
  from typing import Any, TypeVar, final
9
8
 
10
9
  from roborock import (
10
+ AppInitStatus,
11
11
  DeviceProp,
12
12
  DockSummary,
13
13
  RoborockCommand,
@@ -45,7 +45,7 @@ from roborock.containers import (
45
45
  ValleyElectricityTimer,
46
46
  WashTowelMode,
47
47
  )
48
- from roborock.protocol import Utils
48
+ from roborock.protocols.v1_protocol import MapResponse, SecurityData, create_map_response_decoder
49
49
  from roborock.roborock_message import (
50
50
  ROBOROCK_DATA_CONSUMABLE_PROTOCOL,
51
51
  ROBOROCK_DATA_STATUS_PROTOCOL,
@@ -150,10 +150,15 @@ class RoborockClientV1(RoborockClient, ABC):
150
150
  """Roborock client base class for version 1 devices."""
151
151
 
152
152
  _listeners: dict[str, ListenerModel] = {}
153
+ _map_response_decoder: Callable[[RoborockMessage], MapResponse] | None = None
153
154
 
154
- def __init__(self, device_info: DeviceData, endpoint: str):
155
+ def __init__(self, device_info: DeviceData, security_data: SecurityData | None) -> None:
155
156
  """Initializes the Roborock client."""
156
157
  super().__init__(device_info)
158
+ if security_data is not None:
159
+ self._diagnostic_data.update({"misc_info": security_data.to_diagnostic_data()})
160
+ self._map_response_decoder = create_map_response_decoder(security_data)
161
+
157
162
  self._status_type: type[Status] = ModelStatus.get(device_info.model, S7MaxVStatus)
158
163
  self.cache: dict[CacheableAttribute, AttributeCache] = {
159
164
  cacheable_attribute: AttributeCache(attr, self._send_command)
@@ -162,7 +167,6 @@ class RoborockClientV1(RoborockClient, ABC):
162
167
  if device_info.device.duid not in self._listeners:
163
168
  self._listeners[device_info.device.duid] = ListenerModel({}, self.cache)
164
169
  self.listener_model = self._listeners[device_info.device.duid]
165
- self._endpoint = endpoint
166
170
 
167
171
  async def async_release(self) -> None:
168
172
  await super().async_release()
@@ -339,6 +343,10 @@ class RoborockClientV1(RoborockClient, ABC):
339
343
  """Load the map into the vacuum's memory."""
340
344
  await self.send_command(RoborockCommand.LOAD_MULTI_MAP, [map_flag])
341
345
 
346
+ async def get_app_init_status(self) -> AppInitStatus:
347
+ """Gets the app init status (needed for determining vacuum capabilities)."""
348
+ return await self.send_command(RoborockCommand.APP_GET_INIT_STATUS, return_type=AppInitStatus)
349
+
342
350
  @abstractmethod
343
351
  async def _send_command(
344
352
  self,
@@ -429,21 +437,15 @@ class RoborockClientV1(RoborockClient, ABC):
429
437
  dps = {data_point_number: data_point}
430
438
  self._logger.debug(f"Got unknown data point {dps}")
431
439
  elif data.payload and protocol == RoborockMessageProtocol.MAP_RESPONSE:
432
- payload = data.payload[0:24]
433
- [endpoint, _, request_id, _] = struct.unpack("<8s8sH6s", payload)
434
- if endpoint.decode().startswith(self._endpoint):
435
- try:
436
- decrypted = Utils.decrypt_cbc(data.payload[24:], self._nonce)
437
- except ValueError as err:
438
- raise RoborockException(f"Failed to decode {data.payload!r} for {data.protocol}") from err
439
- decompressed = Utils.decompress(decrypted)
440
- queue = self._waiting_queue.get(request_id)
440
+ if self._map_response_decoder is not None:
441
+ map_response = self._map_response_decoder(data)
442
+ queue = self._waiting_queue.get(map_response.request_id)
441
443
  if queue:
442
- if isinstance(decompressed, list):
443
- decompressed = decompressed[0]
444
- queue.set_result(decompressed)
444
+ queue.set_result(map_response.data)
445
445
  else:
446
- self._logger.debug("Received response for unknown request id %s", request_id)
446
+ self._logger.debug(
447
+ "Received unsolicited map response for request_id %s", map_response.request_id
448
+ )
447
449
  else:
448
450
  queue = self._waiting_queue.get(data.seq)
449
451
  if queue:
@@ -10,7 +10,7 @@ from .. import CommandVacuumError, DeviceData, RoborockCommand
10
10
  from ..api import RoborockClient
11
11
  from ..exceptions import RoborockConnectionException, RoborockException, VacuumError
12
12
  from ..protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
13
- from ..protocols.v1_protocol import encode_local_payload
13
+ from ..protocols.v1_protocol import RequestMessage
14
14
  from ..roborock_message import RoborockMessage, RoborockMessageProtocol
15
15
  from ..util import RoborockLoggerAdapter
16
16
  from .roborock_client_v1 import CLOUD_REQUIRED, RoborockClientV1
@@ -60,7 +60,7 @@ class RoborockLocalClientV1(RoborockClientV1, RoborockClient):
60
60
  self.transport: Transport | None = None
61
61
  self._mutex = Lock()
62
62
  self.keep_alive_task: TimerHandle | None = None
63
- RoborockClientV1.__init__(self, device_data, "abc")
63
+ RoborockClientV1.__init__(self, device_data, security_data=None)
64
64
  RoborockClient.__init__(self, device_data)
65
65
  self._local_protocol = _LocalProtocol(self._data_received, self._connection_lost)
66
66
  self._encoder: Encoder = create_local_encoder(device_data.device.local_key)
@@ -123,12 +123,20 @@ class RoborockLocalClientV1(RoborockClientV1, RoborockClient):
123
123
 
124
124
  async def hello(self):
125
125
  try:
126
- return await self._send_message(_HELLO_REQUEST_MESSAGE)
126
+ return await self._send_message(
127
+ roborock_message=_HELLO_REQUEST_MESSAGE,
128
+ request_id=_HELLO_REQUEST_MESSAGE.seq,
129
+ response_protocol=RoborockMessageProtocol.HELLO_RESPONSE,
130
+ )
127
131
  except Exception as e:
128
132
  self._logger.error(e)
129
133
 
130
134
  async def ping(self) -> None:
131
- await self._send_message(_PING_REQUEST_MESSAGE)
135
+ await self._send_message(
136
+ roborock_message=_PING_REQUEST_MESSAGE,
137
+ request_id=_PING_REQUEST_MESSAGE.seq,
138
+ response_protocol=RoborockMessageProtocol.PING_RESPONSE,
139
+ )
132
140
 
133
141
  def _send_msg_raw(self, data: bytes):
134
142
  try:
@@ -145,27 +153,26 @@ class RoborockLocalClientV1(RoborockClientV1, RoborockClient):
145
153
  ):
146
154
  if method in CLOUD_REQUIRED:
147
155
  raise RoborockException(f"Method {method} is not supported over local connection")
148
-
149
- roborock_message = encode_local_payload(method, params)
150
- self._logger.debug("Building message id %s for method %s", roborock_message.get_request_id(), method)
151
- return await self._send_message(roborock_message, method, params)
156
+ request_message = RequestMessage(method=method, params=params)
157
+ roborock_message = request_message.encode_message(RoborockMessageProtocol.GENERAL_REQUEST)
158
+ self._logger.debug("Building message id %s for method %s", request_message.request_id, method)
159
+ return await self._send_message(
160
+ roborock_message,
161
+ request_id=request_message.request_id,
162
+ response_protocol=RoborockMessageProtocol.GENERAL_REQUEST,
163
+ method=method,
164
+ params=params,
165
+ )
152
166
 
153
167
  async def _send_message(
154
168
  self,
155
169
  roborock_message: RoborockMessage,
170
+ request_id: int,
171
+ response_protocol: int,
156
172
  method: str | None = None,
157
173
  params: list | dict | int | None = None,
158
174
  ) -> RoborockMessage:
159
175
  await self.validate_connection()
160
- request_id: int | None
161
- if not method or not method.startswith("get"):
162
- request_id = roborock_message.seq
163
- response_protocol = request_id + 1
164
- else:
165
- request_id = roborock_message.get_request_id()
166
- response_protocol = RoborockMessageProtocol.GENERAL_REQUEST
167
- if request_id is None:
168
- raise RoborockException(f"Failed build message {roborock_message}")
169
176
  msg = self._encoder(roborock_message)
170
177
  if method:
171
178
  self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
@@ -1,4 +1,3 @@
1
- import base64
2
1
  import logging
3
2
 
4
3
  from vacuum_map_parser_base.config.color import ColorsPalette
@@ -10,8 +9,7 @@ from roborock.cloud_api import RoborockMqttClient
10
9
 
11
10
  from ..containers import DeviceData, UserData
12
11
  from ..exceptions import CommandVacuumError, RoborockException, VacuumError
13
- from ..protocol import Utils
14
- from ..protocols.v1_protocol import SecurityData, create_mqtt_payload_encoder
12
+ from ..protocols.v1_protocol import RequestMessage, create_security_data
15
13
  from ..roborock_message import (
16
14
  RoborockMessageProtocol,
17
15
  )
@@ -30,15 +28,12 @@ class RoborockMqttClientV1(RoborockMqttClient, RoborockClientV1):
30
28
  rriot = user_data.rriot
31
29
  if rriot is None:
32
30
  raise RoborockException("Got no rriot data from user_data")
33
- endpoint = base64.b64encode(Utils.md5(rriot.k.encode())[8:14]).decode()
34
-
35
31
  RoborockMqttClient.__init__(self, user_data, device_info)
36
- RoborockClientV1.__init__(self, device_info, endpoint)
32
+ security_data = create_security_data(rriot)
33
+ RoborockClientV1.__init__(self, device_info, security_data=security_data)
37
34
  self.queue_timeout = queue_timeout
38
35
  self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
39
- self._payload_encoder = create_mqtt_payload_encoder(
40
- SecurityData(endpoint=self._endpoint, nonce=self._nonce),
41
- )
36
+ self._security_data = security_data
42
37
 
43
38
  async def _send_command(
44
39
  self,
@@ -49,13 +44,15 @@ class RoborockMqttClientV1(RoborockMqttClient, RoborockClientV1):
49
44
  # When we have more custom commands do something more complicated here
50
45
  return await self._get_calibration_points()
51
46
 
52
- roborock_message = self._payload_encoder(method, params)
53
- self._logger.debug("Building message id %s for method %s", roborock_message.get_request_id, method)
47
+ request_message = RequestMessage(method=method, params=params)
48
+ roborock_message = request_message.encode_message(
49
+ RoborockMessageProtocol.RPC_REQUEST,
50
+ security_data=self._security_data,
51
+ )
52
+ self._logger.debug("Building message id %s for method %s", request_message.request_id, method)
54
53
 
55
54
  await self.validate_connection()
56
- request_id = roborock_message.get_request_id()
57
- if request_id is None:
58
- raise RoborockException(f"Failed build message {roborock_message}")
55
+ request_id = request_message.request_id
59
56
  response_protocol = (
60
57
  RoborockMessageProtocol.MAP_RESPONSE if method in COMMANDS_SECURED else RoborockMessageProtocol.RPC_RESPONSE
61
58
  )
@@ -1,43 +0,0 @@
1
- """Thin wrapper around the MQTT channel for Roborock A01 devices."""
2
-
3
- from __future__ import annotations
4
-
5
- import logging
6
- from typing import Any, overload
7
-
8
- from roborock.protocols.a01_protocol import (
9
- decode_rpc_response,
10
- encode_mqtt_payload,
11
- )
12
- from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol
13
-
14
- from .mqtt_channel import MqttChannel
15
-
16
- _LOGGER = logging.getLogger(__name__)
17
-
18
-
19
- @overload
20
- async def send_decoded_command(
21
- mqtt_channel: MqttChannel,
22
- params: dict[RoborockDyadDataProtocol, Any],
23
- ) -> dict[RoborockDyadDataProtocol, Any]:
24
- ...
25
-
26
-
27
- @overload
28
- async def send_decoded_command(
29
- mqtt_channel: MqttChannel,
30
- params: dict[RoborockZeoProtocol, Any],
31
- ) -> dict[RoborockZeoProtocol, Any]:
32
- ...
33
-
34
-
35
- async def send_decoded_command(
36
- mqtt_channel: MqttChannel,
37
- params: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any],
38
- ) -> dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any]:
39
- """Send a command on the MQTT channel and get a decoded response."""
40
- _LOGGER.debug("Sending MQTT command: %s", params)
41
- roborock_message = encode_mqtt_payload(params)
42
- response = await mqtt_channel.send_message(roborock_message)
43
- return decode_rpc_response(response) # type: ignore[return-value]
@@ -1,137 +0,0 @@
1
- """Modules for communicating with specific Roborock devices over MQTT."""
2
-
3
- import asyncio
4
- import logging
5
- from collections.abc import Callable
6
- from json import JSONDecodeError
7
-
8
- from roborock.containers import HomeDataDevice, RRiot, UserData
9
- from roborock.exceptions import RoborockException
10
- from roborock.mqtt.session import MqttParams, MqttSession
11
- from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder
12
- from roborock.roborock_message import RoborockMessage
13
-
14
- from .channel import Channel
15
-
16
- _LOGGER = logging.getLogger(__name__)
17
-
18
-
19
- class MqttChannel(Channel):
20
- """Simple RPC-style channel for communicating with a device over MQTT.
21
-
22
- Handles request/response correlation and timeouts, but leaves message
23
- format most parsing to higher-level components.
24
- """
25
-
26
- def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot: RRiot, mqtt_params: MqttParams):
27
- self._mqtt_session = mqtt_session
28
- self._duid = duid
29
- self._local_key = local_key
30
- self._rriot = rriot
31
- self._mqtt_params = mqtt_params
32
-
33
- # RPC support
34
- self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {}
35
- self._decoder = create_mqtt_decoder(local_key)
36
- self._encoder = create_mqtt_encoder(local_key)
37
- self._queue_lock = asyncio.Lock()
38
- self._mqtt_unsub: Callable[[], None] | None = None
39
-
40
- @property
41
- def is_connected(self) -> bool:
42
- """Return true if the channel is connected."""
43
- return (self._mqtt_unsub is not None) and self._mqtt_session.connected
44
-
45
- @property
46
- def _publish_topic(self) -> str:
47
- """Topic to send commands to the device."""
48
- return f"rr/m/i/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}"
49
-
50
- @property
51
- def _subscribe_topic(self) -> str:
52
- """Topic to receive responses from the device."""
53
- return f"rr/m/o/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}"
54
-
55
- async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
56
- """Subscribe to the device's response topic.
57
-
58
- The callback will be called with the message payload when a message is received.
59
-
60
- All messages received will be processed through the provided callback, even
61
- those sent in response to the `send_command` command.
62
-
63
- Returns a callable that can be used to unsubscribe from the topic.
64
- """
65
-
66
- def message_handler(payload: bytes) -> None:
67
- if not (messages := self._decoder(payload)):
68
- _LOGGER.warning("Failed to decode MQTT message: %s", payload)
69
- return
70
- for message in messages:
71
- _LOGGER.debug("Received message: %s", message)
72
- asyncio.create_task(self._resolve_future_with_lock(message))
73
- try:
74
- callback(message)
75
- except Exception as e:
76
- _LOGGER.exception("Uncaught error in message handler callback: %s", e)
77
-
78
- self._mqtt_unsub = await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
79
-
80
- def unsub_wrapper() -> None:
81
- if self._mqtt_unsub is not None:
82
- self._mqtt_unsub()
83
- self._mqtt_unsub = None
84
-
85
- return unsub_wrapper
86
-
87
- async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
88
- """Resolve waiting future with proper locking."""
89
- if (request_id := message.get_request_id()) is None:
90
- _LOGGER.debug("Received message with no request_id")
91
- return
92
- async with self._queue_lock:
93
- if (future := self._waiting_queue.pop(request_id, None)) is not None:
94
- future.set_result(message)
95
- else:
96
- _LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
97
-
98
- async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
99
- """Send a command message and wait for the response message.
100
-
101
- Returns the raw response message - caller is responsible for parsing.
102
- """
103
- try:
104
- if (request_id := message.get_request_id()) is None:
105
- raise RoborockException("Message must have a request_id for RPC calls")
106
- except (ValueError, JSONDecodeError) as err:
107
- _LOGGER.exception("Error getting request_id from message: %s", err)
108
- raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
109
-
110
- future: asyncio.Future[RoborockMessage] = asyncio.Future()
111
- async with self._queue_lock:
112
- if request_id in self._waiting_queue:
113
- raise RoborockException(f"Request ID {request_id} already pending, cannot send command")
114
- self._waiting_queue[request_id] = future
115
-
116
- try:
117
- encoded_msg = self._encoder(message)
118
- await self._mqtt_session.publish(self._publish_topic, encoded_msg)
119
-
120
- return await asyncio.wait_for(future, timeout=timeout)
121
-
122
- except asyncio.TimeoutError as ex:
123
- async with self._queue_lock:
124
- self._waiting_queue.pop(request_id, None)
125
- raise RoborockException(f"Command timed out after {timeout}s") from ex
126
- except Exception:
127
- logging.exception("Uncaught error sending command")
128
- async with self._queue_lock:
129
- self._waiting_queue.pop(request_id, None)
130
- raise
131
-
132
-
133
- def create_mqtt_channel(
134
- user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice
135
- ) -> MqttChannel:
136
- """Create a V1Channel for the given device."""
137
- return MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)