ohmqtt 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ohmqtt/__init__.py +0 -0
- ohmqtt/client.py +266 -0
- ohmqtt/connection/__init__.py +117 -0
- ohmqtt/connection/address.py +91 -0
- ohmqtt/connection/decoder.py +132 -0
- ohmqtt/connection/fsm.py +179 -0
- ohmqtt/connection/handlers.py +86 -0
- ohmqtt/connection/keepalive.py +139 -0
- ohmqtt/connection/selector.py +113 -0
- ohmqtt/connection/states.py +468 -0
- ohmqtt/connection/timeout.py +41 -0
- ohmqtt/connection/types.py +98 -0
- ohmqtt/error.py +10 -0
- ohmqtt/logger.py +14 -0
- ohmqtt/mqtt_spec.py +110 -0
- ohmqtt/packet/__init__.py +78 -0
- ohmqtt/packet/auth.py +57 -0
- ohmqtt/packet/base.py +29 -0
- ohmqtt/packet/connect.py +252 -0
- ohmqtt/packet/ping.py +53 -0
- ohmqtt/packet/publish.py +253 -0
- ohmqtt/packet/subscribe.py +200 -0
- ohmqtt/persistence/__init__.py +0 -0
- ohmqtt/persistence/base.py +118 -0
- ohmqtt/persistence/in_memory.py +142 -0
- ohmqtt/persistence/sqlite.py +275 -0
- ohmqtt/platform.py +11 -0
- ohmqtt/property.py +329 -0
- ohmqtt/protected.py +53 -0
- ohmqtt/py.typed +0 -0
- ohmqtt/serialization.py +181 -0
- ohmqtt/session.py +224 -0
- ohmqtt/subscriptions.py +429 -0
- ohmqtt/topic_alias.py +135 -0
- ohmqtt/topic_filter.py +91 -0
- ohmqtt-0.0.3.dist-info/METADATA +103 -0
- ohmqtt-0.0.3.dist-info/RECORD +40 -0
- ohmqtt-0.0.3.dist-info/WHEEL +5 -0
- ohmqtt-0.0.3.dist-info/licenses/LICENSE +7 -0
- ohmqtt-0.0.3.dist-info/top_level.txt +1 -0
ohmqtt/__init__.py
ADDED
|
File without changes
|
ohmqtt/client.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import ssl
|
|
4
|
+
import threading
|
|
5
|
+
from typing import Final, Iterable, Sequence
|
|
6
|
+
import weakref
|
|
7
|
+
|
|
8
|
+
from .connection import Address, ConnectParams, Connection, MessageHandlers
|
|
9
|
+
from .logger import get_logger
|
|
10
|
+
from .mqtt_spec import MQTTReasonCode
|
|
11
|
+
from .packet import MQTTAuthPacket
|
|
12
|
+
from .property import MQTTAuthProps, MQTTConnectProps, MQTTPublishProps, MQTTWillProps
|
|
13
|
+
from .persistence.base import PublishHandle
|
|
14
|
+
from .session import Session
|
|
15
|
+
from .subscriptions import Subscriptions, SubscribeCallback, SubscribeHandle, UnsubscribeHandle, RetainPolicy
|
|
16
|
+
from .topic_alias import AliasPolicy
|
|
17
|
+
|
|
18
|
+
logger: Final = get_logger("client")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Client:
    """High level interface for the MQTT client.

    Owns a Connection, a Session and a Subscriptions object which all share
    one MessageHandlers registry, and exposes thin wrappers around them.
    """
    __slots__ = (
        "__weakref__",
        "_thread",
        "connection",
        "session",
        "subscriptions",
    )

    def __init__(self, db_path: str = "", *, db_fast: bool = False) -> None:
        """Build the connection, subscriptions and session.

        db_path and db_fast are forwarded unchanged to Session.
        """
        # Set by start_loop(); None means no background loop thread was started.
        self._thread: threading.Thread | None = None
        message_handlers = MessageHandlers()
        with message_handlers as handlers:
            self.connection = Connection(handlers)
            # Subscriptions only receives a weak reference to this client.
            self.subscriptions = Subscriptions(handlers, self.connection, weakref.ref(self))
            self.session = Session(handlers, self.connection, db_path=db_path, db_fast=db_fast)
            handlers.register(MQTTAuthPacket, self.handle_auth)

    def __enter__(self) -> Client:
        """Start the background loop thread and return the client."""
        self.start_loop()
        return self

    def __exit__(self, *args: object) -> None:
        """Request shutdown of the connection when leaving the context."""
        self.shutdown()

    def connect(
        self,
        address: str,
        *,
        client_id: str = "",
        clean_start: bool = False,
        connect_timeout: float | None = None,
        reconnect_delay: int = 0,
        keepalive_interval: int = 0,
        tcp_nodelay: bool = True,
        tls_context: ssl.SSLContext | None = None,
        tls_hostname: str = "",
        will_topic: str = "",
        will_payload: bytes = b"",
        will_qos: int = 0,
        will_retain: bool = False,
        will_properties: MQTTWillProps | None = None,
        connect_properties: MQTTConnectProps | None = None,
    ) -> None:
        """Connect to the broker.

        The address string is parsed by Address; all other arguments are
        collected into a ConnectParams which is handed to both the session
        and the connection.
        """
        _address = Address(address)
        params = ConnectParams(
            address=_address,
            client_id=client_id,
            clean_start=clean_start,
            connect_timeout=connect_timeout,
            reconnect_delay=reconnect_delay,
            keepalive_interval=keepalive_interval,
            tcp_nodelay=tcp_nodelay,
            tls_context=tls_context,
            tls_hostname=tls_hostname,
            will_topic=will_topic,
            will_payload=will_payload,
            will_qos=will_qos,
            will_retain=will_retain,
            # Omitted property sets are replaced with fresh, empty instances.
            will_properties=will_properties if will_properties is not None else MQTTWillProps(),
            connect_properties=connect_properties if connect_properties is not None else MQTTConnectProps(),
        )
        self.session.set_params(params)
        self.connection.connect(params)

    def disconnect(self) -> None:
        """Disconnect from the broker."""
        self.connection.disconnect()

    def shutdown(self) -> None:
        """Shutdown the client and close the connection."""
        self.connection.shutdown()

    def publish(
        self,
        topic: str,
        payload: bytes,
        *,
        qos: int = 0,
        retain: bool = False,
        properties: MQTTPublishProps | None = None,
        alias_policy: AliasPolicy = AliasPolicy.NEVER,
    ) -> PublishHandle:
        """Publish a message to a topic.

        All arguments are forwarded to Session.publish; properties is passed
        through as given, including None.
        """
        return self.session.publish(
            topic,
            payload,
            qos=qos,
            retain=retain,
            properties=properties,
            alias_policy=alias_policy,
        )

    def subscribe(
        self,
        topic_filter: str,
        callback: SubscribeCallback,
        max_qos: int = 2,
        *,
        share_name: str | None = None,
        no_local: bool = False,
        retain_as_published: bool = False,
        retain_policy: RetainPolicy = RetainPolicy.ALWAYS,
        sub_id: int | None = None,
        user_properties: Sequence[tuple[str, str]] | None = None,
    ) -> SubscribeHandle | None:
        """Subscribe to a topic filter with a callback.

        If the client is connected, returns a handle which can be used to unsubscribe from the topic filter
        or wait for the subscription to be acknowledged.

        If the client is not connected, returns None."""
        return self.subscriptions.subscribe(
            topic_filter,
            callback,
            max_qos=max_qos,
            share_name=share_name,
            no_local=no_local,
            retain_as_published=retain_as_published,
            retain_policy=retain_policy,
            sub_id=sub_id,
            user_properties=user_properties,
        )

    def unsubscribe(
        # This method must have the same signature as the subscribe method.
        # This lets us match the unsubscribe to the subscribe with the same args.
        self,
        topic_filter: str,
        callback: SubscribeCallback,
        max_qos: int = 2,
        *,
        share_name: str | None = None,
        no_local: bool = False,
        retain_as_published: bool = False,
        retain_policy: RetainPolicy = RetainPolicy.ALWAYS,
        sub_id: int | None = None,
        user_properties: Iterable[tuple[str, str]] | None = None,
        unsub_user_properties: Iterable[tuple[str, str]] | None = None,
    ) -> UnsubscribeHandle | None:
        """Unsubscribe from a topic filter.

        If the client is connected, returns a handle which can be used to wait for the unsubscription to be acknowledged.

        If the client is not connected, returns None."""
        return self.subscriptions.unsubscribe(
            topic_filter,
            callback,
            max_qos=max_qos,
            share_name=share_name,
            no_local=no_local,
            retain_as_published=retain_as_published,
            retain_policy=retain_policy,
            sub_id=sub_id,
            user_properties=user_properties,
            unsub_user_properties=unsub_user_properties,
        )

    def auth(
        self,
        *,
        authentication_method: str | None = None,
        authentication_data: bytes | None = None,
        reason_string: str | None = None,
        user_properties: Sequence[tuple[str, str]] | None = None,
        reason_code: MQTTReasonCode = MQTTReasonCode.Success,
    ) -> None:
        """Send an AUTH packet to the broker.

        Only the properties that were actually supplied (not None) are set on
        the packet.
        """
        properties = MQTTAuthProps()
        if authentication_method is not None:
            properties.AuthenticationMethod = authentication_method
        if authentication_data is not None:
            properties.AuthenticationData = authentication_data
        if reason_string is not None:
            properties.ReasonString = reason_string
        if user_properties is not None:
            properties.UserProperty = user_properties
        packet = MQTTAuthPacket(
            reason_code=reason_code,
            properties=properties,
        )
        self.connection.send(packet)

    def wait_for_connect(self, timeout: float | None = None) -> None:
        """Wait for the client to connect to the broker.

        Raises TimeoutError if the timeout is exceeded."""
        if not self.connection.wait_for_connect(timeout):
            raise TimeoutError("Waiting for connection timed out")

    def wait_for_disconnect(self, timeout: float | None = None) -> None:
        """Wait for the client to disconnect from the broker.

        Raises TimeoutError if the timeout is exceeded."""
        if not self.connection.wait_for_disconnect(timeout):
            raise TimeoutError("Waiting for disconnection timed out")

    def wait_for_shutdown(self, timeout: float | None = None) -> None:
        """Wait for the client to disconnect and finalize.

        Raises TimeoutError if the timeout is exceeded."""
        if not self.connection.wait_for_shutdown(timeout):
            # Distinct message so a shutdown timeout is not mistaken for a
            # disconnect timeout.
            raise TimeoutError("Waiting for shutdown timed out")

    def start_loop(self) -> None:
        """Start the client state machine in a separate thread.

        The thread is a daemon thread running loop_forever().

        Raises RuntimeError if a loop thread was already started."""
        if self._thread is not None:
            raise RuntimeError("Connection loop already started")
        self._thread = threading.Thread(target=self.loop_forever, daemon=True)
        self._thread.start()

    def loop_once(self, max_wait: float | None = 0.0) -> None:
        """Run a single iteration of the MQTT client loop.

        If max_wait is 0.0 (the default), this call will not block.

        If max_wait is None, this call will block until the next event.

        Any other numeric max_wait value may block for maximum that amount of time in seconds."""
        self.connection.loop_once(max_wait)

    def loop_forever(self) -> None:
        """Run the MQTT client loop.

        This will run until the client is stopped or shutdown.
        """
        self.connection.loop_forever()

    def loop_until_connected(self, timeout: float | None = None) -> bool:
        """Run the MQTT client loop until the client is connected to the broker.

        If a timeout is provided, the loop will give up after that amount of time.

        Returns True if the client is connected, False if the timeout was reached."""
        return self.connection.loop_until_connected(timeout)

    def is_connected(self) -> bool:
        """Check if the client is connected to the broker."""
        return self.connection.is_connected()

    def handle_auth(self, packet: MQTTAuthPacket) -> None:
        """Callback for an AUTH packet from the broker.

        Currently only logs that the packet arrived; the packet is not inspected."""
        logger.debug("Got an AUTH packet")
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Final
|
|
4
|
+
|
|
5
|
+
from .address import Address as Address
|
|
6
|
+
from .fsm import FSM
|
|
7
|
+
from .fsm import InvalidStateError as InvalidStateError
|
|
8
|
+
from .handlers import MessageHandlers as MessageHandlers
|
|
9
|
+
from .states import (
|
|
10
|
+
ClosingState,
|
|
11
|
+
ClosedState,
|
|
12
|
+
ConnectingState,
|
|
13
|
+
ConnectedState,
|
|
14
|
+
ReconnectWaitState,
|
|
15
|
+
ShutdownState,
|
|
16
|
+
)
|
|
17
|
+
from .types import ConnectParams as ConnectParams
|
|
18
|
+
from .types import StateEnvironment, ReceivablePacketT, SendablePacketT
|
|
19
|
+
from ..error import MQTTError
|
|
20
|
+
from ..logger import get_logger
|
|
21
|
+
|
|
22
|
+
logger: Final = get_logger("connection")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Connection:
    """Interface for the MQTT connection.

    Wraps a finite state machine (FSM) and routes decoded packets to the
    registered message handlers.
    """
    __slots__ = ("_handlers", "fsm")

    def __init__(self, handlers: MessageHandlers) -> None:
        """Create the state machine, starting in ClosedState.

        handle_packet is installed as the state environment's packet callback.
        """
        state_env = StateEnvironment(packet_callback=self.handle_packet)
        self.fsm = FSM(env=state_env, init_state=ClosedState, error_state=ShutdownState)
        self._handlers = handlers

    def handle_packet(self, packet: ReceivablePacketT) -> None:
        """Handle incoming packets by routing them to registered handlers.

        If the handlers reported any exceptions, one of them is re-raised:
        the first MQTTError if there is one, otherwise the first exception."""
        logger.debug("<--- %s", packet)
        exceptions = self._handlers.handle(packet)
        if not exceptions:
            return
        # Single pass: pick the first MQTTError, if any.
        mqtt_error = next((exc for exc in exceptions if isinstance(exc, MQTTError)), None)
        if mqtt_error is not None:
            raise mqtt_error
        # Otherwise, raise the first exception.
        raise exceptions[0]

    def can_send(self) -> bool:
        """Check if the connection is in a state where data can be sent."""
        with self.fsm.lock:
            state = self.fsm.get_state()
            return state is ConnectedState

    def send(self, packet: SendablePacketT) -> None:
        """Send data to the connection.

        The encoded packet is appended to the write buffer and the selector
        is interrupted.

        Raises InvalidStateError if the connection is not in ConnectedState."""
        # NOTE(review): can_send() takes fsm.lock again while it is held here,
        # so the lock is presumably reentrant - confirm in fsm.py.
        with self.fsm.lock:
            if not self.can_send():
                state = self.fsm.get_state()
                raise InvalidStateError(f"Cannot send data in state {state.__name__}")
            logger.debug("---> %s", packet)
            data = packet.encode()
            self.fsm.env.write_buffer.extend(data)
            self.fsm.selector.interrupt()

    def connect(self, params: ConnectParams) -> None:
        """Connect to the MQTT broker.

        Stores the parameters on the state machine and requests ConnectingState."""
        with self.fsm.lock:
            self.fsm.set_params(params)
            self.fsm.request_state(ConnectingState)
            self.fsm.selector.interrupt()

    def disconnect(self) -> None:
        """Disconnect from the MQTT broker by requesting ClosingState."""
        with self.fsm.lock:
            self.fsm.request_state(ClosingState)
            self.fsm.selector.interrupt()

    def shutdown(self) -> None:
        """Shutdown the connection by requesting ShutdownState."""
        with self.fsm.lock:
            self.fsm.request_state(ShutdownState)
            self.fsm.selector.interrupt()

    def is_connected(self) -> bool:
        """Check if the connection is established."""
        # Identity comparison, matching can_send(): states are classes.
        return self.fsm.get_state() is ConnectedState

    def wait_for_connect(self, timeout: float | None = None) -> bool:
        """Wait for the connection to be established.

        Returns True if the connection is established, False if the timeout is reached."""
        return self.fsm.wait_for_state((ConnectedState,), timeout)

    def wait_for_disconnect(self, timeout: float | None = None) -> bool:
        """Wait for the connection to be closed.

        ClosedState, ShutdownState and ReconnectWaitState all count as closed.

        Returns True if the connection is closed, False if the timeout is reached."""
        return self.fsm.wait_for_state((ClosedState, ShutdownState, ReconnectWaitState), timeout)

    def wait_for_shutdown(self, timeout: float | None = None) -> bool:
        """Wait for the connection to be closed and finalized.

        Returns True if the connection is closed, False if the timeout is reached."""
        return self.fsm.wait_for_state((ShutdownState,), timeout)

    def loop_once(self, max_wait: float | None = 0.0) -> None:
        """Run a single iteration of the state machine.

        If max_wait is None, wait indefinitely. Otherwise, wait for the specified time."""
        self.fsm.loop_once(max_wait)

    def loop_forever(self) -> None:
        """Run the state machine until it reaches ShutdownState."""
        self.fsm.loop_until_state((ShutdownState,))

    def loop_until_connected(self, timeout: float | None = None) -> bool:
        """Run the state machine until the connection is established.

        Returns True if the connection is established, False otherwise."""
        return self.fsm.loop_until_state((ConnectedState,), timeout)
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
import socket
|
|
3
|
+
from typing import Final, Mapping
|
|
4
|
+
from urllib.parse import urlparse, ParseResult
|
|
5
|
+
|
|
6
|
+
from ..platform import AF_UNIX, HAS_AF_UNIX
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Fallback port for each supported URL scheme.
DEFAULT_PORTS: Final[Mapping[str, int]] = dict(
    mqtt=1883,
    mqtts=8883,
    unix=0,
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def is_ipv6(hostname: str) -> bool:
    """Check if the hostname is an IPv6 address.

    Returns True when socket.inet_pton accepts the string as an AF_INET6
    address, False for anything else.
    """
    try:
        socket.inet_pton(socket.AF_INET6, hostname)
    except (OSError, ValueError):
        # OSError: not a valid IPv6 literal.
        # ValueError: the string contains an embedded NUL character.
        return False
    return True
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _get_family(parsed: ParseResult) -> socket.AddressFamily:
    """Map a parsed address to the socket address family to use.

    "unix" yields AF_UNIX when the platform supports it; "mqtt" and "mqtts"
    yield AF_INET6 for an IPv6 literal host and AF_INET otherwise.

    Raises ValueError for any other scheme, or when an mqtt/mqtts address
    has no hostname.
    """
    scheme = parsed.scheme
    if scheme == "unix" and HAS_AF_UNIX:
        return AF_UNIX
    if scheme not in ("mqtt", "mqtts"):
        raise ValueError(f"Unsupported scheme: {parsed.scheme}")
    hostname = parsed.hostname
    if not hostname:
        raise ValueError("Hostname is required for mqtt and mqtts schemes")
    return socket.AF_INET6 if is_ipv6(hostname) else socket.AF_INET
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass(slots=True, init=False, frozen=True, repr=False)
|
|
39
|
+
class Address:
|
|
40
|
+
scheme: str
|
|
41
|
+
family: socket.AddressFamily
|
|
42
|
+
host: str
|
|
43
|
+
port: int
|
|
44
|
+
username: str | None
|
|
45
|
+
password: str | None
|
|
46
|
+
|
|
47
|
+
def __init__(self, address: str = "") -> None:
|
|
48
|
+
"""Parse the address string into family, host, port, username, and password."""
|
|
49
|
+
# Special case: empty address is allowed, but slots will be empty.
|
|
50
|
+
if not address:
|
|
51
|
+
return
|
|
52
|
+
if address.startswith("unix:"):
|
|
53
|
+
if not HAS_AF_UNIX:
|
|
54
|
+
raise ValueError("Unix socket support is not available on this platform")
|
|
55
|
+
elif "//" not in address:
|
|
56
|
+
# urlparse may choke on some network address we wish to support, unless we guarantee a //.
|
|
57
|
+
address = "//" + address
|
|
58
|
+
parsed = urlparse(address, scheme="mqtt")
|
|
59
|
+
object.__setattr__(self, "scheme", parsed.scheme)
|
|
60
|
+
object.__setattr__(self, "family", _get_family(parsed))
|
|
61
|
+
object.__setattr__(self, "host", parsed.hostname or parsed.path)
|
|
62
|
+
if not self.host:
|
|
63
|
+
raise ValueError("No path in address")
|
|
64
|
+
if HAS_AF_UNIX and self.family == AF_UNIX and self.host == "/":
|
|
65
|
+
raise ValueError("'/' is not a valid Unix socket path")
|
|
66
|
+
object.__setattr__(self, "port", parsed.port if parsed.port is not None else DEFAULT_PORTS[parsed.scheme])
|
|
67
|
+
object.__setattr__(self, "username", parsed.username)
|
|
68
|
+
object.__setattr__(self, "password", parsed.password)
|
|
69
|
+
|
|
70
|
+
def __repr__(self) -> str:
|
|
71
|
+
"""Return a string representation of the address."""
|
|
72
|
+
if not hasattr(self, "scheme"):
|
|
73
|
+
# Handle the empty case.
|
|
74
|
+
return "Address()"
|
|
75
|
+
userpw = ""
|
|
76
|
+
if self.username:
|
|
77
|
+
userpw = self.username
|
|
78
|
+
if self.password:
|
|
79
|
+
userpw += ":<hidden>"
|
|
80
|
+
if userpw:
|
|
81
|
+
host = f"{userpw}@{self.host}"
|
|
82
|
+
else:
|
|
83
|
+
host = self.host
|
|
84
|
+
if self.scheme != "unix":
|
|
85
|
+
host = f"{host}:{self.port}"
|
|
86
|
+
return f"Address({self.scheme}://{host})"
|
|
87
|
+
|
|
88
|
+
@property
|
|
89
|
+
def use_tls(self) -> bool:
|
|
90
|
+
"""Check if the address uses TLS."""
|
|
91
|
+
return getattr(self, "scheme", None) == "mqtts"
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
import socket
|
|
3
|
+
import ssl
|
|
4
|
+
from typing import NamedTuple, Final
|
|
5
|
+
|
|
6
|
+
from ..error import MQTTError
|
|
7
|
+
from ..logger import get_logger
|
|
8
|
+
from ..mqtt_spec import MQTTReasonCode
|
|
9
|
+
from ..packet import decode_packet_from_parts, MQTTPacket
|
|
10
|
+
|
|
11
|
+
logger: Final = get_logger("connection.decoder")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class VarintDecodeResult(NamedTuple):
|
|
15
|
+
"""Result of decoding a variable length integer, in part or whole.
|
|
16
|
+
|
|
17
|
+
This state can be used to resume decoding if the socket doesn't have enough data."""
|
|
18
|
+
value: int
|
|
19
|
+
multiplier: int
|
|
20
|
+
complete: bool
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
InitVarintDecodeState: Final = VarintDecodeResult(0, 1, False)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ClosedSocketError(Exception):
|
|
27
|
+
"""Exception raised when the socket is closed."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class WantReadError(Exception):
|
|
31
|
+
"""Indicates that the socket is not ready for reading."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(slots=True)
|
|
35
|
+
class IncrementalDecoder:
|
|
36
|
+
"""Incremental decoder for MQTT messages coming in from a socket.
|
|
37
|
+
|
|
38
|
+
Attributes:
|
|
39
|
+
head: The first byte of the packet.
|
|
40
|
+
length: Variable integer decoding state of the packet length.
|
|
41
|
+
data: The remaining data of the packet.
|
|
42
|
+
"""
|
|
43
|
+
head: int = field(default=-1, init=False)
|
|
44
|
+
length: VarintDecodeResult = field(default=InitVarintDecodeState, init=False)
|
|
45
|
+
data: bytearray = field(init=False, default_factory=bytearray)
|
|
46
|
+
|
|
47
|
+
def reset(self) -> None:
|
|
48
|
+
"""Reset the decoder state."""
|
|
49
|
+
self.head = -1
|
|
50
|
+
self.length = InitVarintDecodeState
|
|
51
|
+
self.data.clear()
|
|
52
|
+
|
|
53
|
+
def _recv_one_byte(self, sock: socket.socket | ssl.SSLSocket) -> int:
|
|
54
|
+
"""Receive one byte from the socket.
|
|
55
|
+
|
|
56
|
+
Raises WantReadError if the socket is not ready for reading.
|
|
57
|
+
|
|
58
|
+
Raises ClosedSocketError if the socket is closed."""
|
|
59
|
+
try:
|
|
60
|
+
data = sock.recv(1)
|
|
61
|
+
except (BlockingIOError, ssl.SSLWantReadError, ssl.SSLWantWriteError):
|
|
62
|
+
raise WantReadError("Socket not ready for reading")
|
|
63
|
+
if not data:
|
|
64
|
+
raise ClosedSocketError("Socket closed")
|
|
65
|
+
return data[0]
|
|
66
|
+
|
|
67
|
+
def _extract_head(self, sock: socket.socket | ssl.SSLSocket) -> None:
|
|
68
|
+
"""Extract the head byte of the next packet from the socket, if needed."""
|
|
69
|
+
if self.head != -1:
|
|
70
|
+
return
|
|
71
|
+
self.head = self._recv_one_byte(sock)
|
|
72
|
+
|
|
73
|
+
def _extract_length(self, sock: socket.socket | ssl.SSLSocket) -> None:
|
|
74
|
+
"""Incrementally decode a variable length integer from a socket, if needed.
|
|
75
|
+
|
|
76
|
+
Raises WantReadError if the socket is not ready for reading."""
|
|
77
|
+
# See ohmqtt.serialization.decode_varint for a cleaner implementation.
|
|
78
|
+
assert self.head != -1 # We shouldn't be here unless we have a head byte.
|
|
79
|
+
if self.length.complete:
|
|
80
|
+
return
|
|
81
|
+
result = self.length.value
|
|
82
|
+
mult = self.length.multiplier
|
|
83
|
+
try:
|
|
84
|
+
while mult < 0x200000: # This magic is the mult value after pulling 4 bytes.
|
|
85
|
+
byte = self._recv_one_byte(sock)
|
|
86
|
+
result += byte % 0x80 * mult
|
|
87
|
+
if byte < 0x80:
|
|
88
|
+
# We have the complete varint.
|
|
89
|
+
self.length = VarintDecodeResult(result, mult, True)
|
|
90
|
+
return
|
|
91
|
+
mult *= 0x80
|
|
92
|
+
raise MQTTError("Varint overflow", MQTTReasonCode.MalformedPacket)
|
|
93
|
+
except WantReadError:
|
|
94
|
+
# Not done yet, the socket is neither closed nor ready for reading.
|
|
95
|
+
# Save the partial state and return.
|
|
96
|
+
self.length = VarintDecodeResult(result, mult, False)
|
|
97
|
+
raise
|
|
98
|
+
|
|
99
|
+
def _extract_data(self, sock: socket.socket | ssl.SSLSocket) -> None:
|
|
100
|
+
"""Extract all data after the packet length from the socket, if needed."""
|
|
101
|
+
assert self.length.complete # We shouldn't be here unless we have a complete length.
|
|
102
|
+
while len(self.data) < self.length.value:
|
|
103
|
+
data = sock.recv(self.length.value - len(self.data))
|
|
104
|
+
if not data:
|
|
105
|
+
raise ClosedSocketError("Socket closed")
|
|
106
|
+
self.data.extend(data)
|
|
107
|
+
|
|
108
|
+
def decode(self, sock: socket.socket | ssl.SSLSocket) -> MQTTPacket | None:
|
|
109
|
+
"""Decode a single packet straight from the socket.
|
|
110
|
+
|
|
111
|
+
Returns None if the socket doesn't have enough data for us to decode a packet.
|
|
112
|
+
|
|
113
|
+
Raises ClosedSocketError if the socket is closed."""
|
|
114
|
+
try:
|
|
115
|
+
self._extract_head(sock)
|
|
116
|
+
self._extract_length(sock)
|
|
117
|
+
self._extract_data(sock)
|
|
118
|
+
except (BlockingIOError, ssl.SSLWantReadError, ssl.SSLWantWriteError, WantReadError):
|
|
119
|
+
# If the socket is open but doesn't have enough data for us, we need to wait for more.
|
|
120
|
+
return None
|
|
121
|
+
except OSError as exc:
|
|
122
|
+
raise ClosedSocketError("Socket closed") from exc
|
|
123
|
+
|
|
124
|
+
# We have a complete packet, decode it and clear the read buffer.
|
|
125
|
+
packet_head = self.head
|
|
126
|
+
packet_data = memoryview(self.data)
|
|
127
|
+
packet_data.toreadonly()
|
|
128
|
+
try:
|
|
129
|
+
return decode_packet_from_parts(packet_head, packet_data)
|
|
130
|
+
finally:
|
|
131
|
+
packet_data.release()
|
|
132
|
+
self.reset()
|