python-roborock 2.36.0__tar.gz → 2.37.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. {python_roborock-2.36.0 → python_roborock-2.37.0}/PKG-INFO +1 -1
  2. {python_roborock-2.36.0 → python_roborock-2.37.0}/pyproject.toml +1 -1
  3. python_roborock-2.37.0/roborock/devices/a01_channel.py +93 -0
  4. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/b01_channel.py +2 -5
  5. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/local_channel.py +10 -28
  6. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/mqtt_channel.py +19 -51
  7. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/traits/b01/props.py +2 -3
  8. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/v1_channel.py +1 -1
  9. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/v1_rpc_channel.py +38 -10
  10. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/protocols/v1_protocol.py +30 -33
  11. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/roborock_message.py +0 -10
  12. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/version_1_apis/roborock_local_client_v1.py +23 -16
  13. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/version_1_apis/roborock_mqtt_client_v1.py +10 -8
  14. python_roborock-2.36.0/roborock/devices/a01_channel.py +0 -43
  15. python_roborock-2.36.0/roborock/devices/pending.py +0 -45
  16. {python_roborock-2.36.0 → python_roborock-2.37.0}/LICENSE +0 -0
  17. {python_roborock-2.36.0 → python_roborock-2.37.0}/README.md +0 -0
  18. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/__init__.py +0 -0
  19. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/api.py +0 -0
  20. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/clean_modes.py +0 -0
  21. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/cli.py +0 -0
  22. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/cloud_api.py +0 -0
  23. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/code_mappings.py +0 -0
  24. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/command_cache.py +0 -0
  25. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/const.py +0 -0
  26. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/containers.py +0 -0
  27. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/device_features.py +0 -0
  28. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/README.md +0 -0
  29. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/__init__.py +0 -0
  30. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/cache.py +0 -0
  31. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/channel.py +0 -0
  32. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/device.py +0 -0
  33. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/device_manager.py +0 -0
  34. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/traits/b01/__init__.py +0 -0
  35. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/traits/dyad.py +0 -0
  36. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/traits/status.py +0 -0
  37. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/traits/trait.py +0 -0
  38. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/devices/traits/zeo.py +0 -0
  39. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/exceptions.py +0 -0
  40. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/mqtt/__init__.py +0 -0
  41. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/mqtt/roborock_session.py +0 -0
  42. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/mqtt/session.py +0 -0
  43. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/protocol.py +0 -0
  44. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/protocols/a01_protocol.py +0 -0
  45. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/protocols/b01_protocol.py +0 -0
  46. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/py.typed +0 -0
  47. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/roborock_future.py +0 -0
  48. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/roborock_typing.py +0 -0
  49. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/util.py +0 -0
  50. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/version_1_apis/__init__.py +0 -0
  51. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/version_1_apis/roborock_client_v1.py +0 -0
  52. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/version_a01_apis/__init__.py +0 -0
  53. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/version_a01_apis/roborock_client_a01.py +0 -0
  54. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/version_a01_apis/roborock_mqtt_client_a01.py +0 -0
  55. {python_roborock-2.36.0 → python_roborock-2.37.0}/roborock/web_api.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: python-roborock
3
- Version: 2.36.0
3
+ Version: 2.37.0
4
4
  Summary: A package to control Roborock vacuums.
5
5
  Home-page: https://github.com/humbertogontijo/python-roborock
6
6
  License: GPL-3.0-only
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "python-roborock"
3
- version = "2.36.0"
3
+ version = "2.37.0"
4
4
  description = "A package to control Roborock vacuums."
5
5
  authors = ["humbertogontijo <humbertogontijo@users.noreply.github.com>"]
6
6
  license = "GPL-3.0-only"
@@ -0,0 +1,93 @@
1
+ """Thin wrapper around the MQTT channel for Roborock A01 devices."""
2
+
3
+ import asyncio
4
+ import logging
5
+ from typing import Any, overload
6
+
7
+ from roborock.exceptions import RoborockException
8
+ from roborock.protocols.a01_protocol import (
9
+ decode_rpc_response,
10
+ encode_mqtt_payload,
11
+ )
12
+ from roborock.roborock_message import (
13
+ RoborockDyadDataProtocol,
14
+ RoborockMessage,
15
+ RoborockZeoProtocol,
16
+ )
17
+
18
+ from .mqtt_channel import MqttChannel
19
+
20
+ _LOGGER = logging.getLogger(__name__)
21
+ _TIMEOUT = 10.0
22
+
23
+ # Both RoborockDyadDataProtocol and RoborockZeoProtocol have the same
24
+ # value for ID_QUERY
25
+ _ID_QUERY = int(RoborockDyadDataProtocol.ID_QUERY)
26
+
27
+
28
+ @overload
29
+ async def send_decoded_command(
30
+ mqtt_channel: MqttChannel,
31
+ params: dict[RoborockDyadDataProtocol, Any],
32
+ ) -> dict[RoborockDyadDataProtocol, Any]:
33
+ ...
34
+
35
+
36
+ @overload
37
+ async def send_decoded_command(
38
+ mqtt_channel: MqttChannel,
39
+ params: dict[RoborockZeoProtocol, Any],
40
+ ) -> dict[RoborockZeoProtocol, Any]:
41
+ ...
42
+
43
+
44
+ async def send_decoded_command(
45
+ mqtt_channel: MqttChannel,
46
+ params: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any],
47
+ ) -> dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any]:
48
+ """Send a command on the MQTT channel and get a decoded response."""
49
+ _LOGGER.debug("Sending MQTT command: %s", params)
50
+ roborock_message = encode_mqtt_payload(params)
51
+
52
+ # For commands that set values: send the command and do not
53
+ # block waiting for a response. Queries are handled below.
54
+ param_values = {int(k): v for k, v in params.items()}
55
+ if not (query_values := param_values.get(_ID_QUERY)):
56
+ await mqtt_channel.publish(roborock_message)
57
+ return {}
58
+
59
+ # Merge any results together than contain the requested data. This
60
+ # does not use a future since it needs to merge results across responses.
61
+ # This could be simplified if we can assume there is a single response.
62
+ finished = asyncio.Event()
63
+ result: dict[int, Any] = {}
64
+
65
+ def find_response(response_message: RoborockMessage) -> None:
66
+ """Handle incoming messages and resolve the future."""
67
+ try:
68
+ decoded = decode_rpc_response(response_message)
69
+ except RoborockException as ex:
70
+ _LOGGER.info("Failed to decode a01 message: %s: %s", response_message, ex)
71
+ return
72
+ for key, value in decoded.items():
73
+ if key in query_values:
74
+ result[key] = value
75
+ if len(result) != len(query_values):
76
+ _LOGGER.debug("Incomplete query response: %s != %s", result, query_values)
77
+ return
78
+ _LOGGER.debug("Received query response: %s", result)
79
+ if not finished.is_set():
80
+ finished.set()
81
+
82
+ unsub = await mqtt_channel.subscribe(find_response)
83
+
84
+ try:
85
+ await mqtt_channel.publish(roborock_message)
86
+ try:
87
+ await asyncio.wait_for(finished.wait(), timeout=_TIMEOUT)
88
+ except TimeoutError as ex:
89
+ raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
90
+ finally:
91
+ unsub()
92
+
93
+ return result # type: ignore[return-value]
@@ -3,12 +3,10 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import logging
6
- from typing import Any
7
6
 
8
7
  from roborock.protocols.b01_protocol import (
9
8
  CommandType,
10
9
  ParamsType,
11
- decode_rpc_response,
12
10
  encode_mqtt_payload,
13
11
  )
14
12
 
@@ -22,9 +20,8 @@ async def send_decoded_command(
22
20
  dps: int,
23
21
  command: CommandType,
24
22
  params: ParamsType,
25
- ) -> dict[int, Any]:
23
+ ) -> None:
26
24
  """Send a command on the MQTT channel and get a decoded response."""
27
25
  _LOGGER.debug("Sending MQTT command: %s", params)
28
26
  roborock_message = encode_mqtt_payload(dps, command, params)
29
- response = await mqtt_channel.send_message(roborock_message)
30
- return decode_rpc_response(response) # type: ignore[return-value]
27
+ await mqtt_channel.publish(roborock_message)
@@ -4,14 +4,12 @@ import asyncio
4
4
  import logging
5
5
  from collections.abc import Callable
6
6
  from dataclasses import dataclass
7
- from json import JSONDecodeError
8
7
 
9
8
  from roborock.exceptions import RoborockConnectionException, RoborockException
10
9
  from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
11
10
  from roborock.roborock_message import RoborockMessage
12
11
 
13
12
  from .channel import Channel
14
- from .pending import PendingRpcs
15
13
 
16
14
  _LOGGER = logging.getLogger(__name__)
17
15
  _PORT = 58867
@@ -47,8 +45,6 @@ class LocalChannel(Channel):
47
45
  self._subscribers: list[Callable[[RoborockMessage], None]] = []
48
46
  self._is_connected = False
49
47
 
50
- # RPC support
51
- self._pending_rpcs: PendingRpcs[int, RoborockMessage] = PendingRpcs()
52
48
  self._decoder: Decoder = create_local_decoder(local_key)
53
49
  self._encoder: Encoder = create_local_encoder(local_key)
54
50
 
@@ -87,7 +83,6 @@ class LocalChannel(Channel):
87
83
  return
88
84
  for message in messages:
89
85
  _LOGGER.debug("Received message: %s", message)
90
- asyncio.create_task(self._resolve_future_with_lock(message))
91
86
  for callback in self._subscribers:
92
87
  try:
93
88
  callback(message)
@@ -109,37 +104,24 @@ class LocalChannel(Channel):
109
104
 
110
105
  return unsubscribe
111
106
 
112
- async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
113
- """Resolve waiting future with proper locking."""
114
- if (request_id := message.get_request_id()) is None:
115
- _LOGGER.debug("Received message with no request_id")
116
- return
117
- await self._pending_rpcs.resolve(request_id, message)
107
+ async def publish(self, message: RoborockMessage) -> None:
108
+ """Send a command message.
118
109
 
119
- async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
120
- """Send a command message and wait for the response message."""
110
+ The caller is responsible for associating the message with its response.
111
+ """
121
112
  if not self._transport or not self._is_connected:
122
113
  raise RoborockConnectionException("Not connected to device")
123
114
 
124
- try:
125
- if (request_id := message.get_request_id()) is None:
126
- raise RoborockException("Message must have a request_id for RPC calls")
127
- except (ValueError, JSONDecodeError) as err:
128
- _LOGGER.exception("Error getting request_id from message: %s", err)
129
- raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
130
-
131
- future: asyncio.Future[RoborockMessage] = await self._pending_rpcs.start(request_id)
132
115
  try:
133
116
  encoded_msg = self._encoder(message)
117
+ except Exception as err:
118
+ _LOGGER.exception("Error encoding MQTT message: %s", err)
119
+ raise RoborockException(f"Failed to encode MQTT message: {err}") from err
120
+ try:
134
121
  self._transport.write(encoded_msg)
135
- return await asyncio.wait_for(future, timeout=timeout)
136
- except asyncio.TimeoutError as ex:
137
- await self._pending_rpcs.pop(request_id)
138
- raise RoborockException(f"Command timed out after {timeout}s") from ex
139
- except Exception:
122
+ except Exception as err:
140
123
  logging.exception("Uncaught error sending command")
141
- await self._pending_rpcs.pop(request_id)
142
- raise
124
+ raise RoborockException(f"Failed to send message: {message}") from err
143
125
 
144
126
 
145
127
  # This module provides a factory function to create LocalChannel instances.
@@ -1,18 +1,15 @@
1
1
  """Modules for communicating with specific Roborock devices over MQTT."""
2
2
 
3
- import asyncio
4
3
  import logging
5
4
  from collections.abc import Callable
6
- from json import JSONDecodeError
7
5
 
8
6
  from roborock.containers import HomeDataDevice, RRiot, UserData
9
7
  from roborock.exceptions import RoborockException
10
- from roborock.mqtt.session import MqttParams, MqttSession
8
+ from roborock.mqtt.session import MqttParams, MqttSession, MqttSessionException
11
9
  from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder
12
10
  from roborock.roborock_message import RoborockMessage
13
11
 
14
12
  from .channel import Channel
15
- from .pending import PendingRpcs
16
13
 
17
14
  _LOGGER = logging.getLogger(__name__)
18
15
 
@@ -31,16 +28,16 @@ class MqttChannel(Channel):
31
28
  self._rriot = rriot
32
29
  self._mqtt_params = mqtt_params
33
30
 
34
- # RPC support
35
- self._pending_rpcs: PendingRpcs[int, RoborockMessage] = PendingRpcs()
36
31
  self._decoder = create_mqtt_decoder(local_key)
37
32
  self._encoder = create_mqtt_encoder(local_key)
38
- self._mqtt_unsub: Callable[[], None] | None = None
39
33
 
40
34
  @property
41
35
  def is_connected(self) -> bool:
42
- """Return true if the channel is connected."""
43
- return (self._mqtt_unsub is not None) and self._mqtt_session.connected
36
+ """Return true if the channel is connected.
37
+
38
+ This passes through the underlying MQTT session's connected state.
39
+ """
40
+ return self._mqtt_session.connected
44
41
 
45
42
  @property
46
43
  def _publish_topic(self) -> str:
@@ -57,9 +54,6 @@ class MqttChannel(Channel):
57
54
 
58
55
  The callback will be called with the message payload when a message is received.
59
56
 
60
- All messages received will be processed through the provided callback, even
61
- those sent in response to the `send_command` command.
62
-
63
57
  Returns a callable that can be used to unsubscribe from the topic.
64
58
  """
65
59
 
@@ -69,55 +63,29 @@ class MqttChannel(Channel):
69
63
  return
70
64
  for message in messages:
71
65
  _LOGGER.debug("Received message: %s", message)
72
- asyncio.create_task(self._resolve_future_with_lock(message))
73
66
  try:
74
67
  callback(message)
75
68
  except Exception as e:
76
69
  _LOGGER.exception("Uncaught error in message handler callback: %s", e)
77
70
 
78
- self._mqtt_unsub = await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
79
-
80
- def unsub_wrapper() -> None:
81
- if self._mqtt_unsub is not None:
82
- self._mqtt_unsub()
83
- self._mqtt_unsub = None
84
-
85
- return unsub_wrapper
86
-
87
- async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
88
- """Resolve waiting future with proper locking."""
89
- if (request_id := message.get_request_id()) is None:
90
- _LOGGER.debug("Received message with no request_id")
91
- return
92
- await self._pending_rpcs.resolve(request_id, message)
71
+ return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
93
72
 
94
- async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
95
- """Send a command message and wait for the response message.
73
+ async def publish(self, message: RoborockMessage) -> None:
74
+ """Publish a command message.
96
75
 
97
- Returns the raw response message - caller is responsible for parsing.
76
+ The caller is responsible for handling any responses and associating them
77
+ with the incoming request.
98
78
  """
99
- try:
100
- if (request_id := message.get_request_id()) is None:
101
- raise RoborockException("Message must have a request_id for RPC calls")
102
- except (ValueError, JSONDecodeError) as err:
103
- _LOGGER.exception("Error getting request_id from message: %s", err)
104
- raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
105
-
106
- future: asyncio.Future[RoborockMessage] = await self._pending_rpcs.start(request_id)
107
-
108
79
  try:
109
80
  encoded_msg = self._encoder(message)
110
- await self._mqtt_session.publish(self._publish_topic, encoded_msg)
111
-
112
- return await asyncio.wait_for(future, timeout=timeout)
113
-
114
- except asyncio.TimeoutError as ex:
115
- await self._pending_rpcs.pop(request_id)
116
- raise RoborockException(f"Command timed out after {timeout}s") from ex
117
- except Exception:
118
- logging.exception("Uncaught error sending command")
119
- await self._pending_rpcs.pop(request_id)
120
- raise
81
+ except Exception as e:
82
+ _LOGGER.exception("Error encoding MQTT message: %s", e)
83
+ raise RoborockException(f"Failed to encode MQTT message: {e}") from e
84
+ try:
85
+ return await self._mqtt_session.publish(self._publish_topic, encoded_msg)
86
+ except MqttSessionException as e:
87
+ _LOGGER.exception("Error publishing MQTT message: %s", e)
88
+ raise RoborockException(f"Failed to publish MQTT message: {e}") from e
121
89
 
122
90
 
123
91
  def create_mqtt_channel(
@@ -1,7 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- from typing import Any
5
4
 
6
5
  from roborock import RoborockB01Methods
7
6
  from roborock.roborock_message import RoborockB01Props
@@ -26,6 +25,6 @@ class B01PropsApi(Trait):
26
25
  """Initialize the B01Props API."""
27
26
  self._channel = channel
28
27
 
29
- async def query_values(self, props: list[RoborockB01Props]) -> dict[int, Any]:
28
+ async def query_values(self, props: list[RoborockB01Props]) -> None:
30
29
  """Query the device for the values of the given Dyad protocols."""
31
- return await send_decoded_command(self._channel, dps=10000, command=RoborockB01Methods.GET_PROP, params=props)
30
+ await send_decoded_command(self._channel, dps=10000, command=RoborockB01Methods.GET_PROP, params=props)
@@ -79,7 +79,7 @@ class V1Channel(Channel):
79
79
  @property
80
80
  def is_mqtt_connected(self) -> bool:
81
81
  """Return whether MQTT connection is available."""
82
- return self._mqtt_unsub is not None
82
+ return self._mqtt_unsub is not None and self._mqtt_channel.is_connected
83
83
 
84
84
  @property
85
85
  def rpc_channel(self) -> V1RpcChannel:
@@ -6,25 +6,27 @@ a simple interface for sending commands and receiving responses over both MQTT
6
6
  and local connections, preferring local when available.
7
7
  """
8
8
 
9
+ import asyncio
9
10
  import logging
10
11
  from collections.abc import Callable
11
12
  from typing import Any, Protocol, TypeVar, overload
12
13
 
13
14
  from roborock.containers import RoborockBase
15
+ from roborock.exceptions import RoborockException
14
16
  from roborock.protocols.v1_protocol import (
15
17
  CommandType,
16
18
  ParamsType,
19
+ RequestMessage,
17
20
  SecurityData,
18
- create_mqtt_payload_encoder,
19
21
  decode_rpc_response,
20
- encode_local_payload,
21
22
  )
22
- from roborock.roborock_message import RoborockMessage
23
+ from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
23
24
 
24
25
  from .local_channel import LocalChannel
25
26
  from .mqtt_channel import MqttChannel
26
27
 
27
28
  _LOGGER = logging.getLogger(__name__)
29
+ _TIMEOUT = 10.0
28
30
 
29
31
 
30
32
  _T = TypeVar("_T", bound=RoborockBase)
@@ -116,7 +118,7 @@ class PayloadEncodedV1RpcChannel(BaseV1RpcChannel):
116
118
  self,
117
119
  name: str,
118
120
  channel: MqttChannel | LocalChannel,
119
- payload_encoder: Callable[[CommandType, ParamsType], RoborockMessage],
121
+ payload_encoder: Callable[[RequestMessage], RoborockMessage],
120
122
  ) -> None:
121
123
  """Initialize the channel with a raw channel and an encoder function."""
122
124
  self._name = name
@@ -131,18 +133,44 @@ class PayloadEncodedV1RpcChannel(BaseV1RpcChannel):
131
133
  ) -> Any:
132
134
  """Send a command and return a parsed response RoborockBase type."""
133
135
  _LOGGER.debug("Sending command (%s): %s, params=%s", self._name, method, params)
134
- message = self._payload_encoder(method, params)
135
- response = await self._channel.send_message(message)
136
- return decode_rpc_response(response)
136
+ request_message = RequestMessage(method, params=params)
137
+ message = self._payload_encoder(request_message)
138
+
139
+ future: asyncio.Future[dict[str, Any]] = asyncio.Future()
140
+
141
+ def find_response(response_message: RoborockMessage) -> None:
142
+ try:
143
+ decoded = decode_rpc_response(response_message)
144
+ except RoborockException:
145
+ return
146
+ if decoded.request_id == request_message.request_id:
147
+ future.set_result(decoded.data)
148
+
149
+ unsub = await self._channel.subscribe(find_response)
150
+ try:
151
+ await self._channel.publish(message)
152
+ return await asyncio.wait_for(future, timeout=_TIMEOUT)
153
+ except TimeoutError as ex:
154
+ future.cancel()
155
+ raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
156
+ finally:
157
+ unsub()
137
158
 
138
159
 
139
160
  def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel:
140
161
  """Create a V1 RPC channel using an MQTT channel."""
141
- payload_encoder = create_mqtt_payload_encoder(security_data)
142
- return PayloadEncodedV1RpcChannel("mqtt", mqtt_channel, payload_encoder)
162
+ return PayloadEncodedV1RpcChannel(
163
+ "mqtt",
164
+ mqtt_channel,
165
+ lambda x: x.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data),
166
+ )
143
167
 
144
168
 
145
169
  def create_combined_rpc_channel(local_channel: LocalChannel, mqtt_rpc_channel: V1RpcChannel) -> V1RpcChannel:
146
170
  """Create a V1 RPC channel that combines local and MQTT channels."""
147
- local_rpc_channel = PayloadEncodedV1RpcChannel("local", local_channel, encode_local_payload)
171
+ local_rpc_channel = PayloadEncodedV1RpcChannel(
172
+ "local",
173
+ local_channel,
174
+ lambda x: x.encode_message(RoborockMessageProtocol.GENERAL_REQUEST),
175
+ )
148
176
  return CombinedV1RpcChannel(local_channel, local_rpc_channel, mqtt_rpc_channel)
@@ -25,8 +25,6 @@ _LOGGER = logging.getLogger(__name__)
25
25
  __all__ = [
26
26
  "SecurityData",
27
27
  "create_security_data",
28
- "create_mqtt_payload_encoder",
29
- "encode_local_payload",
30
28
  "decode_rpc_response",
31
29
  ]
32
30
 
@@ -66,7 +64,19 @@ class RequestMessage:
66
64
  timestamp: int = field(default_factory=lambda: math.floor(time.time()))
67
65
  request_id: int = field(default_factory=lambda: get_next_int(10000, 32767))
68
66
 
69
- def as_payload(self, security_data: SecurityData | None) -> bytes:
67
+ def encode_message(
68
+ self,
69
+ protocol: RoborockMessageProtocol,
70
+ security_data: SecurityData | None = None,
71
+ ) -> RoborockMessage:
72
+ """Convert the request message to a RoborockMessage."""
73
+ return RoborockMessage(
74
+ timestamp=self.timestamp,
75
+ protocol=protocol,
76
+ payload=self._as_payload(security_data=security_data),
77
+ )
78
+
79
+ def _as_payload(self, security_data: SecurityData | None) -> bytes:
70
80
  """Convert the request arguments to a dictionary."""
71
81
  inner = {
72
82
  "id": self.request_id,
@@ -85,36 +95,18 @@ class RequestMessage:
85
95
  )
86
96
 
87
97
 
88
- def create_mqtt_payload_encoder(security_data: SecurityData) -> Callable[[CommandType, ParamsType], RoborockMessage]:
89
- """Create a payload encoder for V1 commands over MQTT."""
90
-
91
- def _get_payload(method: CommandType, params: ParamsType) -> RoborockMessage:
92
- """Build the payload for a V1 command."""
93
- request = RequestMessage(method=method, params=params)
94
- payload = request.as_payload(security_data) # always secure
95
- return RoborockMessage(
96
- timestamp=request.timestamp,
97
- protocol=RoborockMessageProtocol.RPC_REQUEST,
98
- payload=payload,
99
- )
100
-
101
- return _get_payload
102
-
98
+ @dataclass(kw_only=True, frozen=True)
99
+ class ResponseMessage:
100
+ """Data structure for v1 RoborockMessage responses."""
103
101
 
104
- def encode_local_payload(method: CommandType, params: ParamsType) -> RoborockMessage:
105
- """Encode payload for V1 commands over local connection."""
102
+ request_id: int | None
103
+ """The request ID of the response."""
106
104
 
107
- request = RequestMessage(method=method, params=params)
108
- payload = request.as_payload(security_data=None)
105
+ data: dict[str, Any]
106
+ """The data of the response."""
109
107
 
110
- return RoborockMessage(
111
- timestamp=request.timestamp,
112
- protocol=RoborockMessageProtocol.GENERAL_REQUEST,
113
- payload=payload,
114
- )
115
108
 
116
-
117
- def decode_rpc_response(message: RoborockMessage) -> dict[str, Any]:
109
+ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage:
118
110
  """Decode a V1 RPC_RESPONSE message."""
119
111
  if not message.payload:
120
112
  raise RoborockException("Invalid V1 message format: missing payload")
@@ -128,14 +120,19 @@ def decode_rpc_response(message: RoborockMessage) -> dict[str, Any]:
128
120
  if not isinstance(datapoints, dict):
129
121
  raise RoborockException(f"Invalid V1 message format: 'dps' should be a dictionary for {message.payload!r}")
130
122
 
131
- if not (data_point := datapoints.get("102")):
132
- raise RoborockException("Invalid V1 message format: missing '102' data point")
123
+ if not (data_point := datapoints.get(str(RoborockMessageProtocol.RPC_RESPONSE))):
124
+ raise RoborockException(
125
+ f"Invalid V1 message format: missing '{RoborockMessageProtocol.RPC_RESPONSE}' data point"
126
+ )
133
127
 
134
128
  try:
135
129
  data_point_response = json.loads(data_point)
136
130
  except (json.JSONDecodeError, TypeError) as e:
137
- raise RoborockException(f"Invalid V1 message data point '102': {e} for {message.payload!r}") from e
131
+ raise RoborockException(
132
+ f"Invalid V1 message data point '{RoborockMessageProtocol.RPC_RESPONSE}': {e} for {message.payload!r}"
133
+ ) from e
138
134
 
135
+ request_id: int | None = data_point_response.get("id")
139
136
  if error := data_point_response.get("error"):
140
137
  raise RoborockException(f"Error in message: {error}")
141
138
 
@@ -146,7 +143,7 @@ def decode_rpc_response(message: RoborockMessage) -> dict[str, Any]:
146
143
  result = result[0]
147
144
  if not isinstance(result, dict):
148
145
  raise RoborockException(f"Invalid V1 message format: 'result' should be a dictionary for {message.payload!r}")
149
- return result
146
+ return ResponseMessage(request_id=request_id, data=result)
150
147
 
151
148
 
152
149
  @dataclass
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import json
4
3
  import math
5
4
  import time
6
5
  from dataclasses import dataclass, field
@@ -247,12 +246,3 @@ class RoborockMessage:
247
246
  version: bytes = b"1.0"
248
247
  random: int = field(default_factory=lambda: get_next_int(10000, 99999))
249
248
  timestamp: int = field(default_factory=lambda: math.floor(time.time()))
250
-
251
- def get_request_id(self) -> int | None:
252
- if self.payload:
253
- payload = json.loads(self.payload.decode())
254
- for data_point_number, data_point in payload.get("dps").items():
255
- if data_point_number in ["101", "102"]:
256
- data_point_response = json.loads(data_point)
257
- return data_point_response.get("id")
258
- return None
@@ -10,7 +10,7 @@ from .. import CommandVacuumError, DeviceData, RoborockCommand
10
10
  from ..api import RoborockClient
11
11
  from ..exceptions import RoborockConnectionException, RoborockException, VacuumError
12
12
  from ..protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
13
- from ..protocols.v1_protocol import encode_local_payload
13
+ from ..protocols.v1_protocol import RequestMessage
14
14
  from ..roborock_message import RoborockMessage, RoborockMessageProtocol
15
15
  from ..util import RoborockLoggerAdapter
16
16
  from .roborock_client_v1 import CLOUD_REQUIRED, RoborockClientV1
@@ -123,12 +123,20 @@ class RoborockLocalClientV1(RoborockClientV1, RoborockClient):
123
123
 
124
124
  async def hello(self):
125
125
  try:
126
- return await self._send_message(_HELLO_REQUEST_MESSAGE)
126
+ return await self._send_message(
127
+ roborock_message=_HELLO_REQUEST_MESSAGE,
128
+ request_id=_HELLO_REQUEST_MESSAGE.seq,
129
+ response_protocol=RoborockMessageProtocol.HELLO_RESPONSE,
130
+ )
127
131
  except Exception as e:
128
132
  self._logger.error(e)
129
133
 
130
134
  async def ping(self) -> None:
131
- await self._send_message(_PING_REQUEST_MESSAGE)
135
+ await self._send_message(
136
+ roborock_message=_PING_REQUEST_MESSAGE,
137
+ request_id=_PING_REQUEST_MESSAGE.seq,
138
+ response_protocol=RoborockMessageProtocol.PING_RESPONSE,
139
+ )
132
140
 
133
141
  def _send_msg_raw(self, data: bytes):
134
142
  try:
@@ -145,27 +153,26 @@ class RoborockLocalClientV1(RoborockClientV1, RoborockClient):
145
153
  ):
146
154
  if method in CLOUD_REQUIRED:
147
155
  raise RoborockException(f"Method {method} is not supported over local connection")
148
-
149
- roborock_message = encode_local_payload(method, params)
150
- self._logger.debug("Building message id %s for method %s", roborock_message.get_request_id(), method)
151
- return await self._send_message(roborock_message, method, params)
156
+ request_message = RequestMessage(method=method, params=params)
157
+ roborock_message = request_message.encode_message(RoborockMessageProtocol.GENERAL_REQUEST)
158
+ self._logger.debug("Building message id %s for method %s", request_message.request_id, method)
159
+ return await self._send_message(
160
+ roborock_message,
161
+ request_id=request_message.request_id,
162
+ response_protocol=RoborockMessageProtocol.GENERAL_REQUEST,
163
+ method=method,
164
+ params=params,
165
+ )
152
166
 
153
167
  async def _send_message(
154
168
  self,
155
169
  roborock_message: RoborockMessage,
170
+ request_id: int,
171
+ response_protocol: int,
156
172
  method: str | None = None,
157
173
  params: list | dict | int | None = None,
158
174
  ) -> RoborockMessage:
159
175
  await self.validate_connection()
160
- request_id: int | None
161
- if not method or not method.startswith("get"):
162
- request_id = roborock_message.seq
163
- response_protocol = request_id + 1
164
- else:
165
- request_id = roborock_message.get_request_id()
166
- response_protocol = RoborockMessageProtocol.GENERAL_REQUEST
167
- if request_id is None:
168
- raise RoborockException(f"Failed build message {roborock_message}")
169
176
  msg = self._encoder(roborock_message)
170
177
  if method:
171
178
  self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
@@ -9,7 +9,7 @@ from roborock.cloud_api import RoborockMqttClient
9
9
 
10
10
  from ..containers import DeviceData, UserData
11
11
  from ..exceptions import CommandVacuumError, RoborockException, VacuumError
12
- from ..protocols.v1_protocol import create_mqtt_payload_encoder, create_security_data
12
+ from ..protocols.v1_protocol import RequestMessage, create_security_data
13
13
  from ..roborock_message import (
14
14
  RoborockMessageProtocol,
15
15
  )
@@ -28,12 +28,12 @@ class RoborockMqttClientV1(RoborockMqttClient, RoborockClientV1):
28
28
  rriot = user_data.rriot
29
29
  if rriot is None:
30
30
  raise RoborockException("Got no rriot data from user_data")
31
- security_data = create_security_data(rriot)
32
31
  RoborockMqttClient.__init__(self, user_data, device_info)
32
+ security_data = create_security_data(rriot)
33
33
  RoborockClientV1.__init__(self, device_info, security_data=security_data)
34
34
  self.queue_timeout = queue_timeout
35
35
  self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
36
- self._payload_encoder = create_mqtt_payload_encoder(security_data)
36
+ self._security_data = security_data
37
37
 
38
38
  async def _send_command(
39
39
  self,
@@ -44,13 +44,15 @@ class RoborockMqttClientV1(RoborockMqttClient, RoborockClientV1):
44
44
  # When we have more custom commands do something more complicated here
45
45
  return await self._get_calibration_points()
46
46
 
47
- roborock_message = self._payload_encoder(method, params)
48
- self._logger.debug("Building message id %s for method %s", roborock_message.get_request_id, method)
47
+ request_message = RequestMessage(method=method, params=params)
48
+ roborock_message = request_message.encode_message(
49
+ RoborockMessageProtocol.RPC_REQUEST,
50
+ security_data=self._security_data,
51
+ )
52
+ self._logger.debug("Building message id %s for method %s", request_message.request_id, method)
49
53
 
50
54
  await self.validate_connection()
51
- request_id = roborock_message.get_request_id()
52
- if request_id is None:
53
- raise RoborockException(f"Failed build message {roborock_message}")
55
+ request_id = request_message.request_id
54
56
  response_protocol = (
55
57
  RoborockMessageProtocol.MAP_RESPONSE if method in COMMANDS_SECURED else RoborockMessageProtocol.RPC_RESPONSE
56
58
  )
@@ -1,43 +0,0 @@
1
- """Thin wrapper around the MQTT channel for Roborock A01 devices."""
2
-
3
- from __future__ import annotations
4
-
5
- import logging
6
- from typing import Any, overload
7
-
8
- from roborock.protocols.a01_protocol import (
9
- decode_rpc_response,
10
- encode_mqtt_payload,
11
- )
12
- from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol
13
-
14
- from .mqtt_channel import MqttChannel
15
-
16
- _LOGGER = logging.getLogger(__name__)
17
-
18
-
19
- @overload
20
- async def send_decoded_command(
21
- mqtt_channel: MqttChannel,
22
- params: dict[RoborockDyadDataProtocol, Any],
23
- ) -> dict[RoborockDyadDataProtocol, Any]:
24
- ...
25
-
26
-
27
- @overload
28
- async def send_decoded_command(
29
- mqtt_channel: MqttChannel,
30
- params: dict[RoborockZeoProtocol, Any],
31
- ) -> dict[RoborockZeoProtocol, Any]:
32
- ...
33
-
34
-
35
- async def send_decoded_command(
36
- mqtt_channel: MqttChannel,
37
- params: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any],
38
- ) -> dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any]:
39
- """Send a command on the MQTT channel and get a decoded response."""
40
- _LOGGER.debug("Sending MQTT command: %s", params)
41
- roborock_message = encode_mqtt_payload(params)
42
- response = await mqtt_channel.send_message(roborock_message)
43
- return decode_rpc_response(response) # type: ignore[return-value]
@@ -1,45 +0,0 @@
1
- """Module for managing pending RPCs."""
2
-
3
- import asyncio
4
- import logging
5
- from typing import Generic, TypeVar
6
-
7
- from roborock.exceptions import RoborockException
8
-
9
- _LOGGER = logging.getLogger(__name__)
10
-
11
-
12
- K = TypeVar("K")
13
- V = TypeVar("V")
14
-
15
-
16
- class PendingRpcs(Generic[K, V]):
17
- """Manage pending RPCs."""
18
-
19
- def __init__(self) -> None:
20
- """Initialize the pending RPCs."""
21
- self._queue_lock = asyncio.Lock()
22
- self._waiting_queue: dict[K, asyncio.Future[V]] = {}
23
-
24
- async def start(self, key: K) -> asyncio.Future[V]:
25
- """Start the pending RPCs."""
26
- future: asyncio.Future[V] = asyncio.Future()
27
- async with self._queue_lock:
28
- if key in self._waiting_queue:
29
- raise RoborockException(f"Request ID {key} already pending, cannot send command")
30
- self._waiting_queue[key] = future
31
- return future
32
-
33
- async def pop(self, key: K) -> None:
34
- """Pop a pending RPC."""
35
- async with self._queue_lock:
36
- if (future := self._waiting_queue.pop(key, None)) is not None:
37
- future.cancel()
38
-
39
- async def resolve(self, key: K, value: V) -> None:
40
- """Resolve waiting future with proper locking."""
41
- async with self._queue_lock:
42
- if (future := self._waiting_queue.pop(key, None)) is not None:
43
- future.set_result(value)
44
- else:
45
- _LOGGER.debug("Received unsolicited message: %s", key)