matter-python-client 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,806 @@
1
+ """Matter Client implementation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ from typing import TYPE_CHECKING, Any, Final, cast
8
+ import uuid
9
+
10
+ from chip.clusters import Objects as Clusters
11
+ from chip.clusters.Types import NullValue
12
+
13
+ from matter_server.common.errors import NodeNotExists, exception_from_error_code
14
+
15
+ from ..common.helpers.util import (
16
+ convert_ip_address,
17
+ convert_mac_address,
18
+ create_attribute_path_from_attribute,
19
+ dataclass_from_dict,
20
+ dataclass_to_dict,
21
+ )
22
+ from ..common.models import (
23
+ APICommand,
24
+ CommandMessage,
25
+ CommissionableNodeData,
26
+ CommissioningParameters,
27
+ ErrorResultMessage,
28
+ EventMessage,
29
+ EventType,
30
+ MatterNodeData,
31
+ MatterNodeEvent,
32
+ MatterSoftwareVersion,
33
+ MessageType,
34
+ NodePingResult,
35
+ ResultMessageBase,
36
+ ServerDiagnostics,
37
+ ServerInfoMessage,
38
+ SuccessResultMessage,
39
+ )
40
+ from .connection import MatterClientConnection
41
+ from .exceptions import ConnectionClosed, InvalidState, ServerVersionTooOld
42
+ from .models.node import (
43
+ MatterFabricData,
44
+ MatterNode,
45
+ NetworkType,
46
+ NodeDiagnostics,
47
+ NodeType,
48
+ )
49
+
50
+ if TYPE_CHECKING:
51
+ from collections.abc import Callable
52
+ from types import TracebackType
53
+
54
+ from aiohttp import ClientSession
55
+ from chip.clusters.Objects import ClusterCommand
56
+
57
+ SUB_WILDCARD: Final = "*"
58
+
59
+ # pylint: disable=too-many-public-methods,too-many-locals,too-many-branches
60
+
61
+
62
class MatterClient:
    """Manage a Matter server over WebSockets.

    The client keeps a local cache of nodes (filled by `start_listening` and
    kept up to date from incoming events), matches command results to pending
    futures by message id and dispatches server events to the callbacks
    registered through `subscribe_events`.
    """

    def __init__(self, ws_server_url: str, aiohttp_session: ClientSession):
        """Initialize the Client class."""
        self.connection = MatterClientConnection(ws_server_url, aiohttp_session)
        self.logger = logging.getLogger(__package__)
        # nodes known to this client, keyed by node id
        self._nodes: dict[int, MatterNode] = {}
        # futures of commands awaiting a result, keyed by message id
        self._result_futures: dict[str, asyncio.Future] = {}
        # event callbacks, keyed by "<event>/<node>/<attribute path>" filter key
        self._subscribers: dict[str, list[Callable[[EventType, Any], None]]] = {}
        self._stop_called: bool = False
        # set by `connect`; used to create the result futures
        self._loop: asyncio.AbstractEventLoop | None = None

    @property
    def server_info(self) -> ServerInfoMessage | None:
        """Return info of the server we're currently connected to."""
        return self.connection.server_info

    def subscribe_events(
        self,
        callback: Callable[[EventType, Any], None],
        event_filter: EventType | None = None,
        node_filter: int | None = None,
        attr_path_filter: str | None = None,
    ) -> Callable[[], None]:
        """
        Subscribe to node and server events.

        Optionally filter by specific events or node attributes.
        Returns:
            function to unsubscribe. Calling it more than once is harmless.

        NOTE: To receive attribute changed events,
        you must also register the attributes to subscribe to
        with the `subscribe_attributes` method.
        """
        # for fast lookups we create a key based on the filters, allowing
        # a "catch all" with a wildcard (*).
        evt_part = SUB_WILDCARD if event_filter is None else event_filter.value
        node_part = SUB_WILDCARD if node_filter is None else str(node_filter)
        attr_part = SUB_WILDCARD if attr_path_filter is None else attr_path_filter

        key = f"{evt_part}/{node_part}/{attr_part}"
        self._subscribers.setdefault(key, []).append(callback)

        def unsubscribe() -> None:
            callbacks = self._subscribers.get(key)
            if callbacks is None:
                return
            try:
                callbacks.remove(callback)
            except ValueError:
                # already unsubscribed earlier
                return
            if not callbacks:
                # don't keep empty lists around for keys nobody listens to
                self._subscribers.pop(key, None)

        return unsubscribe

    def get_nodes(self) -> list[MatterNode]:
        """Return all Matter nodes."""
        return list(self._nodes.values())

    def get_node(self, node_id: int) -> MatterNode:
        """
        Return Matter node by id.

        Raises NodeNotExists if no node is known by that id.
        """
        if node := self._nodes.get(node_id):
            return node
        raise NodeNotExists(f"Node {node_id} does not exist or is not yet interviewed")

    async def set_default_fabric_label(self, label: str | None) -> None:
        """Set the default fabric label."""
        await self.send_command(
            APICommand.SET_DEFAULT_FABRIC_LABEL, require_schema=11, label=label
        )

    async def commission_with_code(
        self, code: str, network_only: bool = False
    ) -> MatterNodeData:
        """
        Commission a device using a QR Code or Manual Pairing Code.

        :param code: The QR Code or Manual Pairing Code for device commissioning.
        :param network_only: If True, restricts device discovery to network only.

        :return: The MatterNodeData of the commissioned device.
        """
        data = await self.send_command(
            APICommand.COMMISSION_WITH_CODE,
            # the schema requirement only applies when network_only is requested
            require_schema=6 if network_only else None,
            code=code,
            network_only=network_only,
        )
        return dataclass_from_dict(MatterNodeData, data)

    async def commission_on_network(
        self, setup_pin_code: int, ip_addr: str | None = None
    ) -> MatterNodeData:
        """
        Do the routine for OnNetworkCommissioning.

        NOTE: For advanced usecases only, use `commission_with_code`
        for regular commissioning.

        Returns basic MatterNodeData once complete.
        """
        data = await self.send_command(
            APICommand.COMMISSION_ON_NETWORK,
            # the schema requirement only applies when an ip address is given
            require_schema=6 if ip_addr is not None else None,
            setup_pin_code=setup_pin_code,
            ip_addr=ip_addr,
        )
        return dataclass_from_dict(MatterNodeData, data)

    async def set_wifi_credentials(self, ssid: str, credentials: str) -> None:
        """Set WiFi credentials for commissioning to a (new) device."""
        await self.send_command(
            APICommand.SET_WIFI_CREDENTIALS, ssid=ssid, credentials=credentials
        )

    async def set_thread_operational_dataset(self, dataset: str) -> None:
        """Set Thread Operational dataset in the stack."""
        await self.send_command(APICommand.SET_THREAD_DATASET, dataset=dataset)

    async def open_commissioning_window(
        self,
        node_id: int,
        timeout: int = 300,  # noqa: ASYNC109 timeout parameter required for native timeout
        iteration: int = 1000,
        option: int = 1,
        discriminator: int | None = None,
    ) -> CommissioningParameters:
        """
        Open a commissioning window to commission a device present on this controller to another.

        Returns the CommissioningParameters built from the server's response.
        """
        return dataclass_from_dict(
            CommissioningParameters,
            await self.send_command(
                APICommand.OPEN_COMMISSIONING_WINDOW,
                node_id=node_id,
                timeout=timeout,
                iteration=iteration,
                option=option,
                discriminator=discriminator,
            ),
        )

    async def discover_commissionable_nodes(
        self,
    ) -> list[CommissionableNodeData]:
        """Discover Commissionable Nodes (discovered on BLE or mDNS)."""
        return [
            dataclass_from_dict(CommissionableNodeData, x)
            for x in await self.send_command(APICommand.DISCOVER, require_schema=7)
        ]

    async def get_matter_fabrics(self, node_id: int) -> list[MatterFabricData]:
        """
        Get Matter fabrics from a device.

        Returns a list of MatterFabricData objects.
        Raises NodeNotExists if the node is not known.
        """

        node = self.get_node(node_id)

        # refresh node's fabrics if the node is available so we have the latest info
        if node.available:
            await self.refresh_attribute(
                node_id,
                create_attribute_path_from_attribute(
                    0, Clusters.OperationalCredentials.Attributes.Fabrics
                ),
            )

        # guard against a missing attribute value, same as node_diagnostics does
        # for the NetworkInterfaces attribute
        fabrics: list[
            Clusters.OperationalCredentials.Structs.FabricDescriptorStruct
        ] = (
            node.get_attribute_value(
                0, None, Clusters.OperationalCredentials.Attributes.Fabrics
            )
            or []
        )

        vendors_map = await self.send_command(
            APICommand.GET_VENDOR_NAMES,
            require_schema=3,
            filter_vendors=[f.vendorID for f in fabrics],
        )

        return [
            MatterFabricData(
                fabric_id=f.fabricID,
                vendor_id=f.vendorID,
                fabric_index=f.fabricIndex,
                fabric_label=f.label if f.label else None,
                # the vendors map is looked up by the vendor id as string
                vendor_name=vendors_map.get(str(f.vendorID)),
            )
            for f in fabrics
        ]

    async def remove_matter_fabric(self, node_id: int, fabric_index: int) -> None:
        """Remove Matter fabric from a device."""
        await self.send_device_command(
            node_id,
            0,
            Clusters.OperationalCredentials.Commands.RemoveFabric(
                fabricIndex=fabric_index,
            ),
        )

    async def ping_node(self, node_id: int) -> NodePingResult:
        """Ping node on the currently known IP-address(es)."""
        return cast(
            NodePingResult,
            await self.send_command(APICommand.PING_NODE, node_id=node_id),
        )

    async def get_node_ip_addresses(
        self, node_id: int, prefer_cache: bool = True, scoped: bool = False
    ) -> list[str]:
        """
        Return the currently known (scoped) IP-address(es).

        Raises InvalidState when no server info is available (not connected).
        """
        server_info = self.server_info
        if server_info is None:
            # without server info we cannot pick the right lookup method
            raise InvalidState("Not connected")
        if server_info.schema_version >= 8:
            return cast(
                list[str],
                await self.send_command(
                    APICommand.GET_NODE_IP_ADDRESSES,
                    require_schema=8,
                    node_id=node_id,
                    prefer_cache=prefer_cache,
                    scoped=scoped,
                ),
            )
        # alternative method of fetching ip addresses by enumerating NetworkInterfaces
        node = self.get_node(node_id)
        attribute = Clusters.GeneralDiagnostics.Attributes.NetworkInterfaces
        network_interface: Clusters.GeneralDiagnostics.Structs.NetworkInterface
        ip_addresses: list[str] = []
        for network_interface in (
            node.get_attribute_value(0, cluster=None, attribute=attribute) or []
        ):
            # ignore invalid/non-operational interfaces
            if not network_interface.isOperational:
                continue
            # enumerate ipv4 and ipv6 addresses
            ip_addresses.extend(
                convert_ip_address(ipv4_address_hex)
                for ipv4_address_hex in network_interface.IPv4Addresses
            )
            ip_addresses.extend(
                convert_ip_address(ipv6_address_hex, True)
                for ipv6_address_hex in network_interface.IPv6Addresses
            )
            # only the first operational interface is used
            break
        return ip_addresses

    async def node_diagnostics(self, node_id: int) -> NodeDiagnostics:
        """
        Gather diagnostics for the given node.

        Combines the node's IP addresses, details of its first operational
        network interface, thread/wifi specific info and its fabrics.
        """
        # pylint: disable=too-many-statements
        node = self.get_node(node_id)
        ip_addresses = await self.get_node_ip_addresses(node_id)
        # grab some details from the first (operational) network interface
        network_type = NetworkType.UNKNOWN
        mac_address = None
        attribute = Clusters.GeneralDiagnostics.Attributes.NetworkInterfaces
        network_interface: Clusters.GeneralDiagnostics.Structs.NetworkInterface
        for network_interface in (
            node.get_attribute_value(0, cluster=None, attribute=attribute) or []
        ):
            # ignore invalid/non-operational interfaces
            if not network_interface.isOperational:
                continue
            if (
                network_interface.type
                == Clusters.GeneralDiagnostics.Enums.InterfaceTypeEnum.kThread
            ):
                network_type = NetworkType.THREAD
            elif (
                network_interface.type
                == Clusters.GeneralDiagnostics.Enums.InterfaceTypeEnum.kWiFi
            ):
                network_type = NetworkType.WIFI
            elif (
                network_interface.type
                == Clusters.GeneralDiagnostics.Enums.InterfaceTypeEnum.kEthernet
            ):
                network_type = NetworkType.ETHERNET
            else:
                # unknown interface: ignore
                continue
            mac_address = convert_mac_address(network_interface.hardwareAddress)
            break
        else:
            # loop finished without finding a usable interface
            self.logger.warning(
                "Could not determine network_interface info for Node %s, "
                "is it missing the GeneralDiagnostics/NetworkInterfaces Attribute?",
                node_id,
            )
        # get thread/wifi specific info
        node_type = NodeType.UNKNOWN
        network_name = None
        if network_type == NetworkType.THREAD:
            thread_cluster: Clusters.ThreadNetworkDiagnostics = node.get_cluster(
                0, Clusters.ThreadNetworkDiagnostics
            )
            if thread_cluster:
                if isinstance(thread_cluster.networkName, bytes):
                    network_name = thread_cluster.networkName.decode(
                        "utf-8", errors="replace"
                    )
                elif thread_cluster.networkName != NullValue:
                    network_name = thread_cluster.networkName

                # parse routing role to (diagnostics) node type
                RoutingRole = Clusters.ThreadNetworkDiagnostics.Enums.RoutingRoleEnum  # noqa: N806
                if thread_cluster.routingRole == RoutingRole.kSleepyEndDevice:
                    node_type = NodeType.SLEEPY_END_DEVICE
                elif thread_cluster.routingRole in (
                    RoutingRole.kLeader,
                    RoutingRole.kRouter,
                ):
                    node_type = NodeType.ROUTING_END_DEVICE
                elif thread_cluster.routingRole in (
                    RoutingRole.kEndDevice,
                    RoutingRole.kReed,
                ):
                    node_type = NodeType.END_DEVICE
        elif network_type == NetworkType.WIFI:
            node_type = NodeType.END_DEVICE
        # use lastNetworkID from NetworkCommissioning cluster as fallback to get the network name
        # this allows getting the SSID as the wifi diagnostics cluster only has the BSSID
        last_network_id: bytes | str | None
        if not network_name and (
            last_network_id := node.get_attribute_value(
                0,
                cluster=None,
                attribute=Clusters.NetworkCommissioning.Attributes.LastNetworkID,
            )
        ):
            if isinstance(last_network_id, bytes):
                network_name = last_network_id.decode("utf-8", errors="replace")
            elif last_network_id != NullValue:
                network_name = last_network_id
        # last resort to get the (wifi) networkname;
        # enumerate networks on the NetworkCommissioning cluster
        networks: list[Clusters.NetworkCommissioning.Structs.NetworkInfoStruct]
        if not network_name and (
            networks := node.get_attribute_value(
                0,
                cluster=None,
                attribute=Clusters.NetworkCommissioning.Attributes.Networks,
            )
        ):
            for network in networks:
                if not network.connected:
                    continue
                if isinstance(network.networkID, bytes):
                    network_name = network.networkID.decode("utf-8", errors="replace")
                    break
                if network.networkID != NullValue:
                    network_name = network.networkID
                    break
        # override node type if node is a bridge
        if node.node_data.is_bridge:
            node_type = NodeType.BRIDGE
        # get active fabrics for this node
        active_fabrics = await self.get_matter_fabrics(node_id)
        # get active fabric index
        fabric_index = node.get_attribute_value(
            0, None, Clusters.OperationalCredentials.Attributes.CurrentFabricIndex
        )
        return NodeDiagnostics(
            node_id=node_id,
            network_type=network_type,
            node_type=node_type,
            network_name=network_name,
            ip_adresses=ip_addresses,
            mac_address=mac_address,
            available=node.available,
            active_fabrics=active_fabrics,
            active_fabric_index=fabric_index,
        )

    async def send_device_command(
        self,
        node_id: int,
        endpoint_id: int,
        command: ClusterCommand,
        response_type: Any | None = None,
        timed_request_timeout_ms: int | None = None,
        interaction_timeout_ms: int | None = None,
    ) -> Any:
        """
        Send a command to a Matter node/device.

        The command name sent to the server is the name of the command's class.
        """
        if isinstance(command, type):
            # handle case where only the class was provided instead of an instance of it.
            # NOTE(review): the payload is still built from `command` as given;
            # whether dataclass_to_dict accepts a class is not visible here.
            command_name = command.__name__
        else:
            command_name = command.__class__.__name__
        return await self.send_command(
            APICommand.DEVICE_COMMAND,
            node_id=node_id,
            endpoint_id=endpoint_id,
            cluster_id=command.cluster_id,
            command_name=command_name,
            payload=dataclass_to_dict(command),
            response_type=response_type,
            timed_request_timeout_ms=timed_request_timeout_ms,
            interaction_timeout_ms=interaction_timeout_ms,
        )

    async def read_attribute(
        self,
        node_id: int,
        attribute_path: str | list[str],
    ) -> dict[str, Any]:
        """Read one or more attribute(s) on a node by specifying an attributepath."""
        updated_values = await self.send_command(
            APICommand.READ_ATTRIBUTE,
            require_schema=9,
            node_id=node_id,
            attribute_path=attribute_path,
        )
        return cast(dict[str, Any], updated_values)

    async def refresh_attribute(
        self,
        node_id: int,
        attribute_path: str,
    ) -> None:
        """Read attribute(s) on a node and store the updated value(s)."""
        updated_values = await self.read_attribute(node_id, attribute_path)
        for attr_path, value in updated_values.items():
            self._nodes[node_id].update_attribute(attr_path, value)

    async def write_attribute(
        self,
        node_id: int,
        attribute_path: str,
        value: Any,
    ) -> Any:
        """Write an attribute(value) on a target node."""
        return await self.send_command(
            APICommand.WRITE_ATTRIBUTE,
            require_schema=4,
            node_id=node_id,
            attribute_path=attribute_path,
            value=value,
        )

    async def remove_node(self, node_id: int) -> None:
        """Remove a Matter node/device from the fabric."""
        await self.send_command(APICommand.REMOVE_NODE, node_id=node_id)

    async def interview_node(self, node_id: int) -> None:
        """Interview a node."""
        await self.send_command(APICommand.INTERVIEW_NODE, node_id=node_id)

    async def check_node_update(self, node_id: int) -> MatterSoftwareVersion | None:
        """Check Node for updates.

        Return the available update information, or None when the server
        returned no data. Most notable "softwareVersion" contains the integer
        value of the update version which then can be used for the update_node
        command to trigger the update.

        The "softwareVersionString" is a human friendly version string.
        """
        data = await self.send_command(
            APICommand.CHECK_NODE_UPDATE, node_id=node_id, require_schema=10
        )
        if data is None:
            return None

        return dataclass_from_dict(MatterSoftwareVersion, data)

    async def update_node(
        self,
        node_id: int,
        software_version: int | str,
    ) -> None:
        """Start node update to a particular version."""
        await self.send_command(
            APICommand.UPDATE_NODE,
            node_id=node_id,
            software_version=software_version,
            require_schema=10,
        )

    def _prepare_message(
        self,
        command: str,
        require_schema: int | None = None,
        **kwargs: Any,
    ) -> CommandMessage:
        """
        Build the CommandMessage for a command.

        Raises InvalidState when not connected and ServerVersionTooOld when
        the server's schema version is lower than `require_schema`.
        """
        if not self.connection.connected:
            raise InvalidState("Not connected")

        if (
            require_schema is not None
            and self.server_info is not None
            and require_schema > self.server_info.schema_version
        ):
            raise ServerVersionTooOld(
                "Command not available due to incompatible server version. Update the Matter "
                f"Server to a version that supports at least api schema {require_schema}.",
            )

        return CommandMessage(
            message_id=uuid.uuid4().hex,
            command=command,
            args=kwargs,
        )

    async def send_command(
        self,
        command: str,
        require_schema: int | None = None,
        **kwargs: Any,
    ) -> Any:
        """
        Send a command and get a response.

        Raises InvalidState when `connect` was not called yet or the
        connection is not established.
        """
        if not self._loop:
            raise InvalidState("Not connected")

        message = self._prepare_message(command, require_schema, **kwargs)
        future: asyncio.Future[Any] = self._loop.create_future()
        self._result_futures[message.message_id] = future
        try:
            # sending is part of the guarded section so a failed send
            # does not leave a stale future behind
            await self.connection.send_message(message)
            return await future
        finally:
            self._result_futures.pop(message.message_id, None)

    async def send_command_no_wait(
        self,
        command: str,
        require_schema: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Send a command without waiting for the response."""

        message = self._prepare_message(command, require_schema, **kwargs)
        await self.connection.send_message(message)

    async def get_diagnostics(self) -> ServerDiagnostics:
        """Return a full dump of the server (for diagnostics)."""
        data = await self.send_command(APICommand.SERVER_DIAGNOSTICS)
        return dataclass_from_dict(ServerDiagnostics, data)

    async def connect(self) -> None:
        """Connect to the Matter Server (over Websockets)."""
        self._loop = asyncio.get_running_loop()
        if self.connection.connected:
            # already connected
            return

        self._stop_called = False
        # NOTE: connect will raise when connecting failed
        await self.connection.connect()

    async def start_listening(self, init_ready: asyncio.Event | None = None) -> None:
        """
        Start listening to the websocket (and receive initial state).

        Sets `init_ready` (when given) once the initial nodes are loaded and
        keeps handling incoming messages until stopped or the connection is
        closed. The client is always disconnected when this method ends.
        """
        await self.connect()

        try:
            message = CommandMessage(
                message_id=uuid.uuid4().hex, command=APICommand.START_LISTENING
            )
            await self.connection.send_message(message)
            nodes_msg = cast(
                SuccessResultMessage, await self.connection.receive_message_or_raise()
            )
            # a full dump of all nodes will be the result of the start_listening command
            # create MatterNode objects from the basic MatterNodeData objects
            nodes = [
                MatterNode(dataclass_from_dict(MatterNodeData, x))
                for x in nodes_msg.result
            ]
            self._nodes = {node.node_id: node for node in nodes}
            # once we've hit this point we're all set
            self.logger.info("Matter client initialized.")
            if init_ready is not None:
                init_ready.set()

            # keep reading incoming messages
            while not self._stop_called:
                msg = await self.connection.receive_message_or_raise()
                self._handle_incoming_message(msg)
        except ConnectionClosed:
            pass
        finally:
            await self.disconnect()

    async def disconnect(self) -> None:
        """Disconnect the client and cleanup."""
        self._stop_called = True
        # cancel all command-tasks awaiting a result
        # (iterate a snapshot; waiters remove their own entry once they resume)
        for future in list(self._result_futures.values()):
            future.cancel()
        await self.connection.disconnect()

    def _handle_incoming_message(self, msg: MessageType) -> None:
        """
        Handle incoming message.

        Result messages resolve the pending future with the same message id,
        event messages are passed on to `_handle_event_message`.
        """
        # handle result message
        if isinstance(msg, ResultMessageBase):
            future = self._result_futures.get(msg.message_id)

            if future is None or future.done():
                # no listener for this result (anymore), e.g. the waiter was cancelled
                return

            if isinstance(msg, SuccessResultMessage):
                future.set_result(msg.result)
                return
            if isinstance(msg, ErrorResultMessage):
                exc = exception_from_error_code(msg.error_code)
                future.set_exception(exc(msg.details))
                return

        # handle EventMessage
        if isinstance(msg, EventMessage):
            self._handle_event_message(msg)
            return

        # Log anything we can't handle here
        self.logger.debug(
            "Received message with unknown type '%s': %s",
            type(msg),
            msg,
        )

    def _handle_event_message(self, msg: EventMessage) -> None:
        """
        Handle incoming event from the server.

        Updates the local node cache where applicable and signals the event
        to the subscribers.
        """
        if msg.event in (EventType.NODE_ADDED, EventType.NODE_UPDATED):
            # an update event can potentially arrive for a not yet known node
            node_data = dataclass_from_dict(MatterNodeData, msg.data)
            node = self._nodes.get(node_data.node_id)
            if node is None:
                event = EventType.NODE_ADDED
                node = MatterNode(node_data)
                self._nodes[node.node_id] = node
                self.logger.debug("New node added: %s", node.node_id)
            else:
                event = EventType.NODE_UPDATED
                node.update(node_data)
                self.logger.debug("Node updated: %s", node.node_id)
            self._signal_event(event, data=node, node_id=node.node_id)
            return
        if msg.event == EventType.NODE_REMOVED:
            node_id = msg.data
            self.logger.debug("Node removed: %s", node_id)
            self._signal_event(EventType.NODE_REMOVED, data=node_id, node_id=node_id)
            # cleanup node only after signalling subscribers
            self._nodes.pop(node_id, None)
            return
        if msg.event == EventType.ENDPOINT_REMOVED:
            node_id = msg.data["node_id"]
            endpoint_id = msg.data["endpoint_id"]
            self.logger.debug("Endpoint removed: %s/%s", node_id, endpoint_id)
            self._signal_event(
                EventType.ENDPOINT_REMOVED, data=msg.data, node_id=node_id
            )
            # cleanup endpoint only after signalling subscribers
            if node := self._nodes.get(node_id):
                node.endpoints.pop(endpoint_id, None)
            return
        if msg.event == EventType.ATTRIBUTE_UPDATED:
            # data is tuple[node_id, attribute_path, new_value]
            node_id, attribute_path, new_value = msg.data
            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug(
                    "Attribute updated: Node: %s - Attribute: %s - New value: %s",
                    node_id,
                    attribute_path,
                    new_value,
                )
            node = self._nodes.get(node_id)
            if node is None:
                # an update for a node we do not know must not stop the listener
                self.logger.debug(
                    "Ignoring attribute update for unknown node %s", node_id
                )
                return
            node.update_attribute(attribute_path, new_value)
            self._signal_event(
                EventType.ATTRIBUTE_UPDATED,
                data=new_value,
                node_id=node_id,
                attribute_path=attribute_path,
            )
            return
        if msg.event == EventType.ENDPOINT_ADDED:
            node_id = msg.data["node_id"]
            endpoint_id = msg.data["endpoint_id"]
            self.logger.debug("Endpoint added: %s/%s", node_id, endpoint_id)
            # pass the node id along (like ENDPOINT_REMOVED does) so
            # subscribers filtering on a node receive this event too
            self._signal_event(
                EventType.ENDPOINT_ADDED, data=msg.data, node_id=node_id
            )
            return
        if msg.event == EventType.NODE_EVENT:
            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug(
                    "Node event: %s",
                    msg.data,
                )
            node_event = dataclass_from_dict(MatterNodeEvent, msg.data)
            self._signal_event(
                EventType.NODE_EVENT,
                data=node_event,
                node_id=node_event.node_id,
            )
            return
        # simply forward all other events as-is
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug("Received event: %s", msg)
        self._signal_event(msg.event, msg.data)

    def _signal_event(
        self,
        event: EventType,
        data: Any = None,
        node_id: int | None = None,
        attribute_path: str | None = None,
    ) -> None:
        """
        Signal event to all subscribers.

        Callbacks are called with the event type and the data.
        """
        # instead of iterating all subscribers we iterate over subscription keys
        # each callback is stored under a specific key based on the filters
        for evt_key in (event.value, SUB_WILDCARD):
            for node_key in (node_id, SUB_WILDCARD):
                if node_key is None:
                    continue
                for attribute_path_key in (attribute_path, SUB_WILDCARD):
                    if attribute_path_key is None:
                        continue
                    key = f"{evt_key}/{node_key}/{attribute_path_key}"
                    # iterate a snapshot: a callback may (un)subscribe while
                    # the event is being dispatched
                    for callback in tuple(self._subscribers.get(key, ())):
                        callback(event, data)

    async def __aenter__(self) -> "MatterClient":
        """Initialize and connect the Matter Websocket client."""
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Disconnect from the websocket."""
        await self.disconnect()

    def __repr__(self) -> str:
        """Return the representation."""
        url = self.connection.ws_server_url
        prefix = "" if self.connection.connected else "not "
        return f"{type(self).__name__}(ws_server_url={url!r}, {prefix}connected)"