python-oa3-client 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadr3_client/__init__.py +41 -0
- openadr3_client/base.py +141 -0
- openadr3_client/bl.py +18 -0
- openadr3_client/discovery.py +253 -0
- openadr3_client/mqtt.py +190 -0
- openadr3_client/notifications.py +159 -0
- openadr3_client/ven.py +248 -0
- openadr3_client/webhook.py +232 -0
- python_oa3_client-0.1.0.dist-info/METADATA +387 -0
- python_oa3_client-0.1.0.dist-info/RECORD +12 -0
- python_oa3_client-0.1.0.dist-info/WHEEL +4 -0
- python_oa3_client-0.1.0.dist-info/licenses/LICENSE +21 -0
|
openadr3_client/__init__.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from openadr3_client.base import BaseClient
|
|
2
|
+
from openadr3_client.ven import VenClient, extract_topics
|
|
3
|
+
from openadr3_client.bl import BlClient
|
|
4
|
+
from openadr3_client.notifications import (
|
|
5
|
+
MqttChannel,
|
|
6
|
+
NotificationChannel,
|
|
7
|
+
WebhookChannel,
|
|
8
|
+
)
|
|
9
|
+
from openadr3_client.mqtt import MQTTConnection, MQTTMessage, normalize_broker_uri
|
|
10
|
+
from openadr3_client.webhook import WebhookReceiver, WebhookMessage, detect_lan_ip
|
|
11
|
+
from openadr3_client.discovery import (
|
|
12
|
+
DiscoveredVTN,
|
|
13
|
+
DiscoveryMode,
|
|
14
|
+
advertise_vtn,
|
|
15
|
+
discover_vtns,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Public API of the package: `from openadr3_client import *` exports exactly
# these names, all of which are re-exported from the submodule imports above.
__all__ = [
    # Clients
    "VenClient",
    "BlClient",
    "BaseClient",
    # Discovery
    "DiscoveredVTN",
    "DiscoveryMode",
    "discover_vtns",
    "advertise_vtn",
    # Notification channels
    "MqttChannel",
    "WebhookChannel",
    "NotificationChannel",
    # Low-level (still public)
    "MQTTConnection",
    "MQTTMessage",
    "WebhookReceiver",
    "WebhookMessage",
    # Helpers
    "extract_topics",
    "normalize_broker_uri",
    "detect_lan_ip",
]
|
openadr3_client/base.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""Base client with auth, lifecycle, and __getattr__ delegation to OpenADRClient."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import threading
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from openadr3.api import (
|
|
10
|
+
OpenADRClient,
|
|
11
|
+
create_bl_client,
|
|
12
|
+
create_ven_client,
|
|
13
|
+
)
|
|
14
|
+
from openadr3.auth import fetch_token
|
|
15
|
+
|
|
16
|
+
from openadr3_client.discovery import DiscoveryMode, resolve_url
|
|
17
|
+
|
|
18
|
+
# Module-level logger, named after this module (openadr3_client.base).
log = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BaseClient:
    """Lifecycle-managed OpenADR 3 client with __getattr__ delegation.

    Any attribute not found on this class is forwarded to the underlying
    OpenADRClient (once ``start()`` has created it), eliminating the need
    for explicit delegation methods.

    Subclasses pick the underlying client through ``_client_type``:
    ``"ven"`` builds it with ``create_ven_client``; any other value uses
    ``create_bl_client``.
    """

    _client_type: str = "ven"

    def __init__(
        self,
        url: str | None = None,
        token: str | None = None,
        client_id: str | None = None,
        client_secret: str | None = None,
        spec_version: str = "3.1.0",
        spec_path: str | None = None,
        validate: bool = False,
        discovery: str | DiscoveryMode = "never",
        discovery_timeout: float = 3.0,
    ) -> None:
        """Validate and store configuration; no network I/O happens here.

        Args:
            url: VTN base URL. Required when ``discovery`` is ``"never"`` or
                ``"local_with_fallback"``.
            token: caller-supplied bearer token. When omitted, ``client_id``
                and ``client_secret`` are used to fetch one in ``start()``.
            client_id: client-credentials id used to fetch a token.
            client_secret: client-credentials secret used to fetch a token.
            spec_version: stored on the instance; not used by this class.
            spec_path: forwarded to the client factory in ``start()``.
            validate: forwarded to the client factory in ``start()``.
            discovery: a ``DiscoveryMode`` or its string value.
            discovery_timeout: seconds to browse mDNS when discovery runs.

        Raises:
            ValueError: if neither a token nor both credentials are given,
                if ``url`` is missing for a mode that requires it, or if
                ``discovery`` is not a valid ``DiscoveryMode`` value.
        """
        if not token and not (client_id and client_secret):
            raise ValueError("Provide either token or both client_id and client_secret")

        self.discovery_mode = DiscoveryMode(discovery)
        self.discovery_timeout = discovery_timeout

        if self.discovery_mode == DiscoveryMode.NEVER and not url:
            raise ValueError("url is required when discovery='never'")
        if self.discovery_mode == DiscoveryMode.LOCAL_WITH_FALLBACK and not url:
            raise ValueError("url is required when discovery='local_with_fallback'")

        self.url = url
        self.token = token
        self.client_id = client_id
        self.client_secret = client_secret
        self.spec_version = spec_version
        self.spec_path = spec_path
        self.validate = validate

        # True once start() has obtained self.token from client credentials
        # (as opposed to the caller supplying it). Such a token is re-fetched
        # on every restart because it may have expired, or the VTN URL may
        # have been re-discovered, since it was issued.
        self._token_from_credentials = False

        self._resolved_url: str | None = None
        self._api: OpenADRClient | None = None
        self._lock = threading.Lock()

    # -- Lifecycle --

    def start(self) -> BaseClient:
        """Start the client — creates the underlying OpenADRClient.

        Resolves the VTN URL via mDNS discovery (if configured), fetches an
        auth token from the client credentials when no caller-supplied token
        is available, and creates the OpenADRClient. Calling ``start()`` on
        an already-started client only logs and returns.

        Returns:
            ``self``, so calls can be chained.
        """
        if self._api is not None:
            log.info(
                "%s already started: url=%s",
                type(self).__name__, self._resolved_url,
            )
            return self

        self._resolved_url = resolve_url(
            self.discovery_mode, self.url, self.discovery_timeout,
        )

        # A caller-supplied token is always used as-is; a token this class
        # fetched earlier is replaced with a fresh one.
        if not self.token or self._token_from_credentials:
            self.token = fetch_token(
                base_url=self._resolved_url,
                client_id=self.client_id,
                client_secret=self.client_secret,
            )
            self._token_from_credentials = True
            log.info("Token fetched via client credentials: client_id=%s", self.client_id)

        create_fn = create_ven_client if self._client_type == "ven" else create_bl_client
        self._api = create_fn(
            base_url=self._resolved_url,
            token=self.token,
            spec_path=self.spec_path,
            validate=self.validate,
        )
        log.info(
            "%s started: type=%s url=%s",
            type(self).__name__, self._client_type, self._resolved_url,
        )
        return self

    def stop(self) -> BaseClient:
        """Stop the client — close HTTP connection.

        The client is marked stopped before ``close()`` is called. If closing
        raises, a later ``start()`` therefore builds a fresh client instead
        of reporting "already started" with a broken one.

        Returns:
            ``self``, so calls can be chained.
        """
        api, self._api = self._api, None
        if api is not None:
            api.close()
            log.info("%s stopped", type(self).__name__)
        return self

    def __enter__(self) -> BaseClient:
        """Start the client on entry to a ``with`` block."""
        return self.start()

    def __exit__(self, *args: Any) -> None:
        """Stop the client on exit; exceptions are not suppressed."""
        self.stop()

    @property
    def api(self) -> OpenADRClient:
        """The underlying OpenADRClient. Raises if not started."""
        if self._api is None:
            raise RuntimeError(
                f"{type(self).__name__} not started. Call start() first."
            )
        return self._api

    # -- __getattr__ delegation --

    def __getattr__(self, name: str) -> Any:
        """Forward unknown attributes to the underlying OpenADRClient.

        Raises:
            AttributeError: if the client is not started or the underlying
                client lacks ``name``.
        """
        # Read via __dict__ so lookups made before __init__ has set _api
        # (e.g. during copy/unpickling) cannot recurse into __getattr__.
        api = self.__dict__.get("_api")
        if api is not None:
            try:
                return getattr(api, name)
            except AttributeError:
                pass
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )
|
openadr3_client/bl.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""BlClient — Business Logic client for program/event management."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from openadr3_client.base import BaseClient
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BlClient(BaseClient):
    """Business Logic (BL) flavour of the OpenADR 3 client.

    The only difference from ``BaseClient`` is the ``_client_type`` marker.
    It is set to ``"bl"``, so ``start()`` builds the underlying API object
    with ``create_bl_client`` rather than ``create_ven_client``.

    BL clients create and manage programs and events. They have no VEN
    registration or notification concepts.

    Every OpenADRClient method is reachable through ``__getattr__``
    delegation once the client has been started.
    """

    _client_type = "bl"
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
"""mDNS/DNS-SD discovery for OpenADR 3 VTNs.
|
|
2
|
+
|
|
3
|
+
Browses for ``_openadr3._tcp.local.`` services via zeroconf and parses
|
|
4
|
+
TXT records defined in the OpenADR 3.1.0 specification (section 4.1).
|
|
5
|
+
|
|
6
|
+
The ``zeroconf`` package is lazily imported so that the core library
|
|
7
|
+
works without it installed — add the ``[mdns]`` extra to pull it in.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import logging
|
|
13
|
+
import threading
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from enum import Enum
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
# Module-level logger, named after this module (openadr3_client.discovery).
log = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
# DNS-SD service type browsed by discover_vtns() and registered by advertise_vtn().
SERVICE_TYPE = "_openadr3._tcp.local."
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class DiscoveryMode(str, Enum):
    """How the client should discover a VTN.

    A ``str`` mixin, so ``DiscoveryMode("never")`` accepts the plain string
    values. The modes are interpreted by ``resolve_url``; see each member.
    """

    # Use only the configured URL; mDNS is never browsed.
    NEVER = "never"
    # Use a VTN found via mDNS if any, else the configured URL; error if neither.
    PREFER_LOCAL = "prefer_local"
    # Use a VTN found via mDNS if any, else the configured URL
    # (BaseClient requires a URL for this mode).
    LOCAL_WITH_FALLBACK = "local_with_fallback"
    # Use a VTN found via mDNS; error if none is found.
    REQUIRE_LOCAL = "require_local"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _parse_txt_properties(properties: dict[bytes | str, bytes | str | None]) -> dict[str, str]:
|
|
33
|
+
"""Convert zeroconf TXT record properties to ``{str: str}``."""
|
|
34
|
+
result: dict[str, str] = {}
|
|
35
|
+
for k, v in properties.items():
|
|
36
|
+
key = k.decode("utf-8") if isinstance(k, bytes) else k
|
|
37
|
+
if v is None:
|
|
38
|
+
result[key] = ""
|
|
39
|
+
elif isinstance(v, bytes):
|
|
40
|
+
result[key] = v.decode("utf-8")
|
|
41
|
+
else:
|
|
42
|
+
result[key] = str(v)
|
|
43
|
+
return result
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass(frozen=True)
class DiscoveredVTN:
    """A VTN discovered via mDNS/DNS-SD.

    All fields except ``name``, ``host`` and ``port`` hold the raw TXT
    record value of the same key (see ``from_service_info``).
    """

    name: str  # service instance name as reported by zeroconf (info.name)
    host: str  # hostname (trailing dot stripped) or IP address literal
    port: int
    base_path: str = "/"
    version: str = ""
    local_url: str = ""
    program_names: str = ""
    requires_auth: str = ""
    openapi_url: str = ""

    @property
    def url(self) -> str:
        """Resolved URL: ``local_url`` if set, else constructed from host/port/base_path.

        The scheme is guessed from the port: ``https`` for 443, ``http``
        otherwise. ``base_path`` is normalised to exactly one leading slash
        and no trailing slash, and IPv6 literals are bracketed, so the
        result is always a well-formed URL.
        """
        if self.local_url:
            return self.local_url.rstrip("/")
        scheme = "https" if self.port == 443 else "http"
        path = self.base_path.strip("/")
        base = f"/{path}" if path else ""
        # An IPv6 literal must be bracketed inside a URL authority.
        host = self.host
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return f"{scheme}://{host}:{self.port}{base}"

    @classmethod
    def from_service_info(cls, info: Any) -> DiscoveredVTN:
        """Build from a ``zeroconf.ServiceInfo`` object.

        The host is ``info.server`` without its trailing dot. If that is
        empty, the first parsed address is used, else ``"localhost"``.
        """
        props = _parse_txt_properties(info.properties or {})
        if info.server:
            # Prefer .server (the .local hostname) over parsed addresses
            host = info.server.rstrip(".")
        else:
            addresses = info.parsed_addresses()
            host = addresses[0] if addresses else "localhost"
        return cls(
            name=info.name,
            host=host,
            port=info.port,
            base_path=props.get("base_path", "/"),
            version=props.get("version", ""),
            local_url=props.get("local_url", ""),
            program_names=props.get("program_names", ""),
            requires_auth=props.get("requires_auth", ""),
            openapi_url=props.get("openapi_url", ""),
        )
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _import_zeroconf():
|
|
91
|
+
"""Lazy-import zeroconf with a helpful error message."""
|
|
92
|
+
try:
|
|
93
|
+
import zeroconf # noqa: F811
|
|
94
|
+
return zeroconf
|
|
95
|
+
except ImportError:
|
|
96
|
+
raise ImportError(
|
|
97
|
+
"zeroconf is required for mDNS discovery. "
|
|
98
|
+
"Install it with: pip install 'python-oa3-client[mdns]'"
|
|
99
|
+
) from None
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def discover_vtns(timeout: float = 3.0) -> list[DiscoveredVTN]:
    """Browse for OpenADR 3 VTNs via mDNS.

    Blocks for *timeout* seconds while collecting service announcements,
    then returns all discovered VTNs in the order they were reported.

    Raises:
        ImportError: if the optional ``zeroconf`` package is not installed.
    """
    zc_mod = _import_zeroconf()
    Zeroconf = zc_mod.Zeroconf
    ServiceBrowser = zc_mod.ServiceBrowser

    found: list[DiscoveredVTN] = []
    # Listener callbacks are invoked by zeroconf (presumably on its own
    # thread), so the shared list is only touched under this lock.
    found_lock = threading.Lock()
    # Nothing ever sets this event; waiting on it is a plain sleep of *timeout*.
    event = threading.Event()

    class Listener:
        def add_service(self, zc: Any, type_: str, name: str) -> None:
            info = zc.get_service_info(type_, name)
            if info:
                vtn = DiscoveredVTN.from_service_info(info)
                with found_lock:
                    found.append(vtn)
                log.info("Discovered VTN: %s at %s", vtn.name, vtn.url)

        def remove_service(self, zc: Any, type_: str, name: str) -> None:
            pass

        def update_service(self, zc: Any, type_: str, name: str) -> None:
            pass

    zc = Zeroconf()
    browser = None
    try:
        browser = ServiceBrowser(zc, SERVICE_TYPE, Listener())
        event.wait(timeout)
    finally:
        # Stop browsing explicitly before tearing down the Zeroconf instance.
        if browser is not None:
            browser.cancel()
        zc.close()

    with found_lock:
        # Return a copy so the caller never shares the list with callbacks.
        return list(found)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def resolve_url(
    mode: DiscoveryMode | str,
    configured_url: str | None,
    timeout: float = 3.0,
) -> str:
    """Resolve the VTN URL based on discovery mode.

    Called by ``BaseClient.start()`` to determine the final URL.

    * ``never``: return *configured_url*; mDNS is not browsed.
    * ``require_local``: return the first VTN found via mDNS.
    * ``prefer_local`` / ``local_with_fallback``: return the first VTN
      found via mDNS, else *configured_url*.

    Raises:
        ValueError: *mode* is not a valid ``DiscoveryMode``, or mode is
            ``never`` and no *configured_url* is given.
        RuntimeError: no URL can be determined in the other modes. This
            includes ``local_with_fallback`` without a configured URL, which
            previously returned ``None`` despite the ``str`` return type.
        ImportError: mDNS browsing is needed but zeroconf is not installed.
    """
    mode = DiscoveryMode(mode)

    if mode == DiscoveryMode.NEVER:
        if not configured_url:
            raise ValueError("url is required when discovery='never'")
        return configured_url

    vtns = discover_vtns(timeout=timeout)
    # When several VTNs answer, the first one discovered wins.
    discovered_url = vtns[0].url if vtns else None

    # Every remaining mode prefers a locally discovered VTN.
    if discovered_url:
        return discovered_url

    if mode == DiscoveryMode.REQUIRE_LOCAL:
        raise RuntimeError(
            "discovery='require_local' but no VTN found via mDNS"
        )

    # PREFER_LOCAL and LOCAL_WITH_FALLBACK both fall back to the configured URL.
    if configured_url:
        return configured_url
    raise RuntimeError(
        f"discovery='{mode.value}': no VTN found via mDNS and no url configured"
    )
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class _VTNAdvertiser:
|
|
182
|
+
"""Context manager that registers an mDNS service for a VTN."""
|
|
183
|
+
|
|
184
|
+
def __init__(self, info: Any, zc: Any) -> None:
|
|
185
|
+
self._info = info
|
|
186
|
+
self._zc = zc
|
|
187
|
+
|
|
188
|
+
def close(self) -> None:
|
|
189
|
+
"""Unregister the service and shut down zeroconf."""
|
|
190
|
+
self._zc.unregister_service(self._info)
|
|
191
|
+
self._zc.close()
|
|
192
|
+
log.info("mDNS service unregistered: %s", self._info.name)
|
|
193
|
+
|
|
194
|
+
def __enter__(self) -> _VTNAdvertiser:
|
|
195
|
+
return self
|
|
196
|
+
|
|
197
|
+
def __exit__(self, *args: Any) -> None:
|
|
198
|
+
self.close()
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def advertise_vtn(
    host: str,
    port: int,
    base_path: str = "/",
    version: str = "3.1.0",
    local_url: str = "",
    program_names: str = "",
    requires_auth: str = "false",
    openapi_url: str = "",
    name: str = "OpenADR3 VTN",
) -> _VTNAdvertiser:
    """Register an mDNS service advertising a VTN.

    Returns a context manager / object with ``.close()`` to unregister.

    Useful for testing discovery against a running VTN-RI without
    modifying VTN-RI itself.

    Args:
        host: hostname or IPv4 literal. It is advertised as the service's
            server name. Its IPv4 address is the advertised address:
            ``"localhost"`` maps to ``127.0.0.1``, and other non-literal
            hostnames are resolved with ``socket.gethostbyname``.
        port: TCP port of the VTN.
        base_path, version: always published as TXT properties.
        local_url, program_names, requires_auth, openapi_url: published as
            TXT properties only when non-empty.
        name: service instance name (``"<name>._openadr3._tcp.local."``).

    Raises:
        ImportError: if zeroconf is not installed.
        OSError: if *host* cannot be resolved to an IPv4 address.
    """
    zc_mod = _import_zeroconf()
    import socket

    Zeroconf = zc_mod.Zeroconf
    ServiceInfo = zc_mod.ServiceInfo

    properties = {
        "version": version,
        "base_path": base_path,
    }
    optional = {
        "local_url": local_url,
        "program_names": program_names,
        "requires_auth": requires_auth,
        "openapi_url": openapi_url,
    }
    # Same insertion order as before: optional keys follow the mandatory ones.
    properties.update({key: value for key, value in optional.items() if value})

    literal = "127.0.0.1" if host in ("localhost", "127.0.0.1") else host
    try:
        packed_address = socket.inet_aton(literal)
    except OSError:
        # Not an IPv4 literal (e.g. "vtn.local"): resolve it rather than crash.
        packed_address = socket.inet_aton(socket.gethostbyname(host))

    info = ServiceInfo(
        SERVICE_TYPE,
        f"{name}.{SERVICE_TYPE}",
        # Fully-qualified server name; avoid "host.." when host already ends in ".".
        server=f"{host.rstrip('.')}.",
        port=port,
        properties=properties,
        addresses=[packed_address],
    )

    zc = Zeroconf()
    try:
        zc.register_service(info)
    except BaseException:
        # Do not leak the Zeroconf instance (and its threads) on failure.
        zc.close()
        raise
    log.info("mDNS service registered: %s on %s:%d", name, host, port)

    return _VTNAdvertiser(info, zc)
|
openadr3_client/mqtt.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""MQTT notification support for OpenADR 3 clients.
|
|
2
|
+
|
|
3
|
+
Connects to an MQTT broker via ebus-mqtt-client and collects messages
|
|
4
|
+
in a thread-safe list. Payloads that look like OpenADR notifications
|
|
5
|
+
are automatically coerced into Notification models.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
import re
|
|
13
|
+
import threading
|
|
14
|
+
import time
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
from typing import Any, Callable
|
|
17
|
+
from urllib.parse import urlparse
|
|
18
|
+
|
|
19
|
+
from openadr3.entities import coerce_notification, is_notification
|
|
20
|
+
|
|
21
|
+
# Module-level logger, named after this module (openadr3_client.mqtt).
log = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def normalize_broker_uri(uri: str) -> tuple[str, int, bool]:
    """Translate an MQTT URI into (host, port, use_tls).

    Supports mqtt://, mqtts://, tcp://, ssl:// schemes. ``mqtts`` and
    ``ssl`` select TLS; any other scheme is treated as plain TCP. A bare
    ``host`` or ``host:port`` with no scheme is treated as ``mqtt://``.
    Adds default ports (1883 for plain, 8883 for TLS) when omitted. Falls
    back to host ``127.0.0.1`` when the URI names no host.

    Raises:
        ValueError: if the port is not a valid integer in range (raised by
            ``urlparse(...).port``).
    """
    if "://" not in uri:
        # Without a scheme, urlparse reads "broker:1883" as scheme "broker"
        # and "broker" as a bare path; both silently yielded 127.0.0.1.
        uri = f"mqtt://{uri}"

    parsed = urlparse(uri)
    use_tls = parsed.scheme.lower() in ("mqtts", "ssl")
    host = parsed.hostname or "127.0.0.1"
    port = parsed.port or (8883 if use_tls else 1883)
    return host, port, use_tls
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _parse_payload(raw: bytes, topic: str) -> Any:
|
|
45
|
+
"""Parse MQTT payload bytes as JSON, coercing notifications."""
|
|
46
|
+
try:
|
|
47
|
+
s = raw.decode("utf-8")
|
|
48
|
+
except UnicodeDecodeError:
|
|
49
|
+
return raw
|
|
50
|
+
|
|
51
|
+
try:
|
|
52
|
+
parsed = json.loads(s)
|
|
53
|
+
except (json.JSONDecodeError, ValueError):
|
|
54
|
+
return s
|
|
55
|
+
|
|
56
|
+
if isinstance(parsed, dict) and is_notification(parsed):
|
|
57
|
+
return coerce_notification(
|
|
58
|
+
parsed, {"openadr/channel": "mqtt", "openadr/topic": topic}
|
|
59
|
+
)
|
|
60
|
+
return parsed
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass
class MQTTMessage:
    """A received MQTT message, as recorded by ``MQTTConnection``."""

    # Topic the message arrived on.
    topic: str
    # Result of _parse_payload: a coerced notification, another JSON value,
    # the UTF-8 text, or the raw bytes if they are not valid UTF-8.
    payload: Any
    # Receipt time from time.time() (wall-clock seconds since the epoch).
    time: float
    # Payload bytes exactly as handed to the message handler.
    raw_payload: bytes
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class MQTTConnection:
    """MQTT connection with thread-safe message collection.

    Wraps ebus-mqtt-client's MqttClient, adding:
    - Message collection in a thread-safe list
    - Notification payload coercion
    - Await helpers for testing
    """

    def __init__(
        self,
        broker_url: str,
        client_id: str | None = None,
        on_message: Callable[[str, Any], None] | None = None,
        tls_insecure: bool = True,
    ) -> None:
        """Configure the connection; nothing is opened until ``connect()``.

        Args:
            broker_url: broker URI (see ``normalize_broker_uri``).
            client_id: MQTT client id; defaults to ``oa3-<hex id(self)>``.
            on_message: optional ``callback(topic, parsed_payload)``. It is
                invoked for every received message after the message has
                been recorded.
            tls_insecure: forwarded to ``MqttClient(tls_insecure=...)``.
                Defaults to ``True`` to keep the previous hard-coded
                behaviour. NOTE(review): presumably this relaxes TLS
                certificate verification; pass ``False`` for production
                brokers. Confirm against ebus-mqtt-client.
        """
        self.broker_url = broker_url
        self.client_id = client_id or f"oa3-{id(self):x}"
        self.on_message_callback = on_message
        self.tls_insecure = tls_insecure
        self._messages: list[MQTTMessage] = []
        self._lock = threading.Lock()
        self._client: MqttClient | None = None

    def connect(self) -> None:
        """Connect to the MQTT broker.

        A client left over from an earlier ``connect()`` is disconnected
        first, so it does not keep running unreferenced in the background.

        Raises:
            ImportError: if ebus-mqtt-client is not installed.
        """
        try:
            from ebus_mqtt_client import MqttClient
        except ImportError:
            raise ImportError(
                "ebus-mqtt-client is required for MQTT support. "
                "Install it with: pip install python-oa3-client[mqtt]"
            ) from None

        if self._client is not None:
            self.disconnect()

        host, port, use_tls = normalize_broker_uri(self.broker_url)
        client = MqttClient(
            client_id=self.client_id,
            endpoint=host,
            port=port,
            use_tls=use_tls,
            tls_insecure=self.tls_insecure,
        )
        client.start()
        # Only publish the client once it has started successfully.
        self._client = client
        log.info("MQTT connected: broker=%s client_id=%s", self.broker_url, self.client_id)

    def disconnect(self) -> None:
        """Disconnect from the MQTT broker; a no-op when not connected.

        The connection is forgotten before stopping, so it is cleared even
        if ``stop()`` raises.
        """
        client, self._client = self._client, None
        if client is None:
            return
        client.stop()
        log.info("MQTT disconnected: broker=%s", self.broker_url)

    def is_connected(self) -> bool:
        """Whether the underlying client reports itself connected."""
        return self._client.is_connected() if self._client else False

    def subscribe(self, topics: list[str] | str) -> None:
        """Subscribe to one or more MQTT topics.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
        """
        if not self._client:
            raise RuntimeError("Not connected. Call connect() first.")
        if isinstance(topics, str):
            topics = [topics]
        for topic in topics:
            self._client.subscribe(topic, self._handle_message)
        log.info("MQTT subscribed: topics=%s", topics)

    def _handle_message(self, topic: str, payload: bytes) -> None:
        """Internal callback — parse, collect, and dispatch."""
        parsed = _parse_payload(payload, topic)
        msg = MQTTMessage(
            topic=topic,
            payload=parsed,
            time=time.time(),
            raw_payload=payload,
        )
        with self._lock:
            self._messages.append(msg)
        log.debug("MQTT message: topic=%s", topic)
        callback = self.on_message_callback
        if callback is not None:
            try:
                callback(topic, parsed)
            except Exception:
                # This runs on the MQTT client's delivery path; log the user
                # error instead of letting it escape into the library.
                log.exception("MQTT on_message callback failed: topic=%s", topic)

    @property
    def messages(self) -> list[MQTTMessage]:
        """All collected messages (snapshot)."""
        with self._lock:
            return list(self._messages)

    def messages_on_topic(self, topic: str) -> list[MQTTMessage]:
        """Messages received on a specific topic (snapshot)."""
        with self._lock:
            return [m for m in self._messages if m.topic == topic]

    def clear_messages(self) -> None:
        """Clear collected messages."""
        with self._lock:
            self._messages.clear()

    @staticmethod
    def _poll_until(
        fetch: Callable[[], list[MQTTMessage]], n: int, timeout: float
    ) -> list[MQTTMessage]:
        """Poll ``fetch()`` every 50 ms until it has >= n items or *timeout* passes.

        Uses a monotonic clock, so wall-clock adjustments cannot shorten or
        extend the wait. Returns the last snapshot either way.
        """
        deadline = time.monotonic() + timeout
        while True:
            msgs = fetch()
            if len(msgs) >= n or time.monotonic() >= deadline:
                return msgs
            time.sleep(0.05)

    def await_messages(self, n: int, timeout: float = 5.0) -> list[MQTTMessage]:
        """Wait until at least n messages collected, or timeout."""
        return self._poll_until(lambda: self.messages, n, timeout)

    def await_messages_on_topic(
        self, topic: str, n: int, timeout: float = 5.0
    ) -> list[MQTTMessage]:
        """Wait until at least n messages on a specific topic, or timeout."""
        return self._poll_until(lambda: self.messages_on_topic(topic), n, timeout)
|