bumble 0.0.180__py3-none-any.whl → 0.0.181__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bumble/hfp.py CHANGED
@@ -22,7 +22,7 @@ import dataclasses
22
22
  import enum
23
23
  import traceback
24
24
  import warnings
25
- from typing import Dict, List, Union, Set, TYPE_CHECKING
25
+ from typing import Dict, List, Union, Set, Any, TYPE_CHECKING
26
26
 
27
27
  from . import at
28
28
  from . import rfcomm
@@ -35,7 +35,11 @@ from bumble.core import (
35
35
  BT_L2CAP_PROTOCOL_ID,
36
36
  BT_RFCOMM_PROTOCOL_ID,
37
37
  )
38
- from bumble.hci import HCI_Enhanced_Setup_Synchronous_Connection_Command
38
+ from bumble.hci import (
39
+ HCI_Enhanced_Setup_Synchronous_Connection_Command,
40
+ CodingFormat,
41
+ CodecID,
42
+ )
39
43
  from bumble.sdp import (
40
44
  DataElement,
41
45
  ServiceAttribute,
@@ -66,6 +70,7 @@ class HfpProtocolError(ProtocolError):
66
70
  # Protocol Support
67
71
  # -----------------------------------------------------------------------------
68
72
 
73
+
69
74
  # -----------------------------------------------------------------------------
70
75
  class HfpProtocol:
71
76
  dlc: rfcomm.DLC
@@ -842,19 +847,15 @@ class DefaultCodecParameters(enum.IntEnum):
842
847
  @dataclasses.dataclass
843
848
  class EscoParameters:
844
849
  # Codec specific
845
- transmit_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat
846
- receive_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat
850
+ transmit_coding_format: CodingFormat
851
+ receive_coding_format: CodingFormat
847
852
  packet_type: HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType
848
853
  retransmission_effort: HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort
849
854
  max_latency: int
850
855
 
851
856
  # Common
852
- input_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat = (
853
- HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.TRANSPARENT
854
- )
855
- output_coding_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat = (
856
- HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.TRANSPARENT
857
- )
857
+ input_coding_format: CodingFormat = CodingFormat(CodecID.LINEAR_PCM)
858
+ output_coding_format: CodingFormat = CodingFormat(CodecID.LINEAR_PCM)
858
859
  input_coded_data_size: int = 16
859
860
  output_coded_data_size: int = 16
860
861
  input_pcm_data_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat = (
@@ -880,26 +881,31 @@ class EscoParameters:
880
881
  transmit_codec_frame_size: int = 60
881
882
  receive_codec_frame_size: int = 60
882
883
 
884
+ def asdict(self) -> Dict[str, Any]:
885
+ # dataclasses.asdict() will recursively deep-copy the entire object,
886
+ # which is expensive and breaks CodingFormat object, so let it simply copy here.
887
+ return self.__dict__
888
+
883
889
 
884
890
  _ESCO_PARAMETERS_CVSD_D0 = EscoParameters(
885
- transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
886
- receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
891
+ transmit_coding_format=CodingFormat(CodecID.CVSD),
892
+ receive_coding_format=CodingFormat(CodecID.CVSD),
887
893
  max_latency=0xFFFF,
888
894
  packet_type=HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.HV1,
889
895
  retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.NO_RETRANSMISSION,
890
896
  )
891
897
 
892
898
  _ESCO_PARAMETERS_CVSD_D1 = EscoParameters(
893
- transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
894
- receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
899
+ transmit_coding_format=CodingFormat(CodecID.CVSD),
900
+ receive_coding_format=CodingFormat(CodecID.CVSD),
895
901
  max_latency=0xFFFF,
896
902
  packet_type=HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.HV3,
897
903
  retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.NO_RETRANSMISSION,
898
904
  )
899
905
 
900
906
  _ESCO_PARAMETERS_CVSD_S1 = EscoParameters(
901
- transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
902
- receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
907
+ transmit_coding_format=CodingFormat(CodecID.CVSD),
908
+ receive_coding_format=CodingFormat(CodecID.CVSD),
903
909
  max_latency=0x0007,
904
910
  packet_type=(
905
911
  HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
@@ -912,8 +918,8 @@ _ESCO_PARAMETERS_CVSD_S1 = EscoParameters(
912
918
  )
913
919
 
914
920
  _ESCO_PARAMETERS_CVSD_S2 = EscoParameters(
915
- transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
916
- receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
921
+ transmit_coding_format=CodingFormat(CodecID.CVSD),
922
+ receive_coding_format=CodingFormat(CodecID.CVSD),
917
923
  max_latency=0x0007,
918
924
  packet_type=(
919
925
  HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
@@ -925,8 +931,8 @@ _ESCO_PARAMETERS_CVSD_S2 = EscoParameters(
925
931
  )
926
932
 
927
933
  _ESCO_PARAMETERS_CVSD_S3 = EscoParameters(
928
- transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
929
- receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
934
+ transmit_coding_format=CodingFormat(CodecID.CVSD),
935
+ receive_coding_format=CodingFormat(CodecID.CVSD),
930
936
  max_latency=0x000A,
931
937
  packet_type=(
932
938
  HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
@@ -938,8 +944,8 @@ _ESCO_PARAMETERS_CVSD_S3 = EscoParameters(
938
944
  )
939
945
 
940
946
  _ESCO_PARAMETERS_CVSD_S4 = EscoParameters(
941
- transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
942
- receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.CVSD,
947
+ transmit_coding_format=CodingFormat(CodecID.CVSD),
948
+ receive_coding_format=CodingFormat(CodecID.CVSD),
943
949
  max_latency=0x000C,
944
950
  packet_type=(
945
951
  HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
@@ -951,8 +957,8 @@ _ESCO_PARAMETERS_CVSD_S4 = EscoParameters(
951
957
  )
952
958
 
953
959
  _ESCO_PARAMETERS_MSBC_T1 = EscoParameters(
954
- transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.MSBC,
955
- receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.MSBC,
960
+ transmit_coding_format=CodingFormat(CodecID.MSBC),
961
+ receive_coding_format=CodingFormat(CodecID.MSBC),
956
962
  max_latency=0x0008,
957
963
  packet_type=(
958
964
  HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
@@ -960,12 +966,14 @@ _ESCO_PARAMETERS_MSBC_T1 = EscoParameters(
960
966
  | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
961
967
  | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
962
968
  ),
969
+ input_bandwidth=32000,
970
+ output_bandwidth=32000,
963
971
  retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_QUALITY,
964
972
  )
965
973
 
966
974
  _ESCO_PARAMETERS_MSBC_T2 = EscoParameters(
967
- transmit_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.MSBC,
968
- receive_coding_format=HCI_Enhanced_Setup_Synchronous_Connection_Command.CodingFormat.MSBC,
975
+ transmit_coding_format=CodingFormat(CodecID.MSBC),
976
+ receive_coding_format=CodingFormat(CodecID.MSBC),
969
977
  max_latency=0x000D,
970
978
  packet_type=(
971
979
  HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
@@ -974,10 +982,12 @@ _ESCO_PARAMETERS_MSBC_T2 = EscoParameters(
974
982
  | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
975
983
  | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
976
984
  ),
985
+ input_bandwidth=32000,
986
+ output_bandwidth=32000,
977
987
  retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_QUALITY,
978
988
  )
979
989
 
980
- ESCO_PERAMETERS = {
990
+ ESCO_PARAMETERS = {
981
991
  DefaultCodecParameters.SCO_CVSD_D0: _ESCO_PARAMETERS_CVSD_D0,
982
992
  DefaultCodecParameters.SCO_CVSD_D1: _ESCO_PARAMETERS_CVSD_D1,
983
993
  DefaultCodecParameters.ESCO_CVSD_S1: _ESCO_PARAMETERS_CVSD_S1,
bumble/hid.py CHANGED
@@ -19,16 +19,17 @@ from __future__ import annotations
19
19
  from dataclasses import dataclass
20
20
  import logging
21
21
  import enum
22
+ import struct
22
23
 
24
+ from abc import ABC, abstractmethod
23
25
  from pyee import EventEmitter
24
- from typing import Optional, TYPE_CHECKING
26
+ from typing import Optional, Callable, TYPE_CHECKING
27
+ from typing_extensions import override
25
28
 
26
- from bumble import l2cap
29
+ from bumble import l2cap, device
27
30
  from bumble.colors import color
28
31
  from bumble.core import InvalidStateError, ProtocolError
29
-
30
- if TYPE_CHECKING:
31
- from bumble.device import Device, Connection
32
+ from .hci import Address
32
33
 
33
34
 
34
35
  # -----------------------------------------------------------------------------
@@ -60,6 +61,7 @@ class Message:
60
61
  NOT_READY = 0x01
61
62
  ERR_INVALID_REPORT_ID = 0x02
62
63
  ERR_UNSUPPORTED_REQUEST = 0x03
64
+ ERR_INVALID_PARAMETER = 0x04
63
65
  ERR_UNKNOWN = 0x0E
64
66
  ERR_FATAL = 0x0F
65
67
 
@@ -101,13 +103,14 @@ class GetReportMessage(Message):
101
103
  def __bytes__(self) -> bytes:
102
104
  packet_bytes = bytearray()
103
105
  packet_bytes.append(self.report_id)
104
- packet_bytes.extend(
105
- [(self.buffer_size & 0xFF), ((self.buffer_size >> 8) & 0xFF)]
106
- )
107
- if self.report_type == Message.ReportType.OTHER_REPORT:
106
+ if self.buffer_size == 0:
108
107
  return self.header(self.report_type) + packet_bytes
109
108
  else:
110
- return self.header(0x08 | self.report_type) + packet_bytes
109
+ return (
110
+ self.header(0x08 | self.report_type)
111
+ + packet_bytes
112
+ + struct.pack("<H", self.buffer_size)
113
+ )
111
114
 
112
115
 
113
116
  @dataclass
@@ -120,6 +123,16 @@ class SetReportMessage(Message):
120
123
  return self.header(self.report_type) + self.data
121
124
 
122
125
 
126
+ @dataclass
127
+ class SendControlData(Message):
128
+ report_type: int
129
+ data: bytes
130
+ message_type = Message.MessageType.DATA
131
+
132
+ def __bytes__(self) -> bytes:
133
+ return self.header(self.report_type) + self.data
134
+
135
+
123
136
  @dataclass
124
137
  class GetProtocolMessage(Message):
125
138
  message_type = Message.MessageType.GET_PROTOCOL
@@ -161,31 +174,47 @@ class VirtualCableUnplug(Message):
161
174
  return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)
162
175
 
163
176
 
177
+ # Device sends input report, host sends output report.
164
178
  @dataclass
165
179
  class SendData(Message):
166
180
  data: bytes
181
+ report_type: int
167
182
  message_type = Message.MessageType.DATA
168
183
 
169
184
  def __bytes__(self) -> bytes:
170
- return self.header(Message.ReportType.OUTPUT_REPORT) + self.data
185
+ return self.header(self.report_type) + self.data
186
+
187
+
188
+ @dataclass
189
+ class SendHandshakeMessage(Message):
190
+ result_code: int
191
+ message_type = Message.MessageType.HANDSHAKE
192
+
193
+ def __bytes__(self) -> bytes:
194
+ return self.header(self.result_code)
171
195
 
172
196
 
173
197
  # -----------------------------------------------------------------------------
174
- class Host(EventEmitter):
175
- l2cap_ctrl_channel: Optional[l2cap.ClassicChannel]
176
- l2cap_intr_channel: Optional[l2cap.ClassicChannel]
198
+ class HID(ABC, EventEmitter):
199
+ l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
200
+ l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
201
+ connection: Optional[device.Connection] = None
202
+
203
+ class Role(enum.IntEnum):
204
+ HOST = 0x00
205
+ DEVICE = 0x01
177
206
 
178
- def __init__(self, device: Device, connection: Connection) -> None:
207
+ def __init__(self, device: device.Device, role: Role) -> None:
179
208
  super().__init__()
209
+ self.remote_device_bd_address: Optional[Address] = None
180
210
  self.device = device
181
- self.connection = connection
182
-
183
- self.l2cap_ctrl_channel = None
184
- self.l2cap_intr_channel = None
211
+ self.role = role
185
212
 
186
213
  # Register ourselves with the L2CAP channel manager
187
- device.register_l2cap_server(HID_CONTROL_PSM, self.on_connection)
188
- device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_connection)
214
+ device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
215
+ device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
216
+
217
+ device.on('connection', self.on_device_connection)
189
218
 
190
219
  async def connect_control_channel(self) -> None:
191
220
  # Create a new L2CAP connection - control channel
@@ -229,9 +258,18 @@ class Host(EventEmitter):
229
258
  self.l2cap_ctrl_channel = None
230
259
  await channel.disconnect()
231
260
 
232
- def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
261
+ def on_device_connection(self, connection: device.Connection) -> None:
262
+ self.connection = connection
263
+ self.remote_device_bd_address = connection.peer_address
264
+ connection.on('disconnection', self.on_device_disconnection)
265
+
266
+ def on_device_disconnection(self, reason: int) -> None:
267
+ self.connection = None
268
+
269
+ def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
233
270
  logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
234
271
  l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
272
+ l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
235
273
 
236
274
  def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
237
275
  if l2cap_channel.psm == HID_CONTROL_PSM:
@@ -242,37 +280,220 @@ class Host(EventEmitter):
242
280
  self.l2cap_intr_channel.sink = self.on_intr_pdu
243
281
  logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
244
282
 
283
+ def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
284
+ if l2cap_channel.psm == HID_CONTROL_PSM:
285
+ self.l2cap_ctrl_channel = None
286
+ else:
287
+ self.l2cap_intr_channel = None
288
+ logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')
289
+
290
+ @abstractmethod
291
+ def on_ctrl_pdu(self, pdu: bytes) -> None:
292
+ pass
293
+
294
+ def on_intr_pdu(self, pdu: bytes) -> None:
295
+ logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
296
+ self.emit("interrupt_data", pdu)
297
+
298
+ def send_pdu_on_ctrl(self, msg: bytes) -> None:
299
+ assert self.l2cap_ctrl_channel
300
+ self.l2cap_ctrl_channel.send_pdu(msg)
301
+
302
+ def send_pdu_on_intr(self, msg: bytes) -> None:
303
+ assert self.l2cap_intr_channel
304
+ self.l2cap_intr_channel.send_pdu(msg)
305
+
306
+ def send_data(self, data: bytes) -> None:
307
+ if self.role == HID.Role.HOST:
308
+ report_type = Message.ReportType.OUTPUT_REPORT
309
+ else:
310
+ report_type = Message.ReportType.INPUT_REPORT
311
+ msg = SendData(data, report_type)
312
+ hid_message = bytes(msg)
313
+ if self.l2cap_intr_channel is not None:
314
+ logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
315
+ self.send_pdu_on_intr(hid_message)
316
+
317
+ def virtual_cable_unplug(self) -> None:
318
+ msg = VirtualCableUnplug()
319
+ hid_message = bytes(msg)
320
+ logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
321
+ self.send_pdu_on_ctrl(hid_message)
322
+
323
+
324
+ # -----------------------------------------------------------------------------
325
+
326
+
327
+ class Device(HID):
328
+ class GetSetReturn(enum.IntEnum):
329
+ FAILURE = 0x00
330
+ REPORT_ID_NOT_FOUND = 0x01
331
+ ERR_UNSUPPORTED_REQUEST = 0x02
332
+ ERR_UNKNOWN = 0x03
333
+ ERR_INVALID_PARAMETER = 0x04
334
+ SUCCESS = 0xFF
335
+
336
+ class GetSetStatus:
337
+ def __init__(self) -> None:
338
+ self.data = bytearray()
339
+ self.status = 0
340
+
341
+ def __init__(self, device: device.Device) -> None:
342
+ super().__init__(device, HID.Role.DEVICE)
343
+ get_report_cb: Optional[Callable[[int, int, int], None]] = None
344
+ set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
345
+ get_protocol_cb: Optional[Callable[[], None]] = None
346
+ set_protocol_cb: Optional[Callable[[int], None]] = None
347
+
348
+ @override
245
349
  def on_ctrl_pdu(self, pdu: bytes) -> None:
246
350
  logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
247
- # Here we will receive all kinds of packets, parse and then call respective callbacks
248
- message_type = pdu[0] >> 4
249
351
  param = pdu[0] & 0x0F
352
+ message_type = pdu[0] >> 4
250
353
 
251
- if message_type == Message.MessageType.HANDSHAKE:
252
- logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
253
- self.emit('handshake', Message.Handshake(param))
354
+ if message_type == Message.MessageType.GET_REPORT:
355
+ logger.debug('<<< HID GET REPORT')
356
+ self.handle_get_report(pdu)
357
+ elif message_type == Message.MessageType.SET_REPORT:
358
+ logger.debug('<<< HID SET REPORT')
359
+ self.handle_set_report(pdu)
360
+ elif message_type == Message.MessageType.GET_PROTOCOL:
361
+ logger.debug('<<< HID GET PROTOCOL')
362
+ self.handle_get_protocol(pdu)
363
+ elif message_type == Message.MessageType.SET_PROTOCOL:
364
+ logger.debug('<<< HID SET PROTOCOL')
365
+ self.handle_set_protocol(pdu)
254
366
  elif message_type == Message.MessageType.DATA:
255
367
  logger.debug('<<< HID CONTROL DATA')
256
- self.emit('data', pdu)
368
+ self.emit('control_data', pdu)
257
369
  elif message_type == Message.MessageType.CONTROL:
258
370
  if param == Message.ControlCommand.SUSPEND:
259
371
  logger.debug('<<< HID SUSPEND')
260
- self.emit('suspend', pdu)
372
+ self.emit('suspend')
261
373
  elif param == Message.ControlCommand.EXIT_SUSPEND:
262
374
  logger.debug('<<< HID EXIT SUSPEND')
263
- self.emit('exit_suspend', pdu)
375
+ self.emit('exit_suspend')
264
376
  elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
265
377
  logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
266
378
  self.emit('virtual_cable_unplug')
267
379
  else:
268
380
  logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
269
381
  else:
270
- logger.debug('<<< HID CONTROL DATA')
271
- self.emit('data', pdu)
382
+ logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
383
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
272
384
 
273
- def on_intr_pdu(self, pdu: bytes) -> None:
274
- logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
275
- self.emit("data", pdu)
385
+ def send_handshake_message(self, result_code: int) -> None:
386
+ msg = SendHandshakeMessage(result_code)
387
+ hid_message = bytes(msg)
388
+ logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
389
+ self.send_pdu_on_ctrl(hid_message)
390
+
391
+ def send_control_data(self, report_type: int, data: bytes):
392
+ msg = SendControlData(report_type=report_type, data=data)
393
+ hid_message = bytes(msg)
394
+ logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
395
+ self.send_pdu_on_ctrl(hid_message)
396
+
397
+ def handle_get_report(self, pdu: bytes):
398
+ if self.get_report_cb is None:
399
+ logger.debug("GetReport callback not registered !!")
400
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
401
+ return
402
+ report_type = pdu[0] & 0x03
403
+ buffer_flag = (pdu[0] & 0x08) >> 3
404
+ report_id = pdu[1]
405
+ logger.debug(f"buffer_flag: {buffer_flag}")
406
+ if buffer_flag == 1:
407
+ buffer_size = (pdu[3] << 8) | pdu[2]
408
+ else:
409
+ buffer_size = 0
410
+
411
+ ret = self.get_report_cb(report_id, report_type, buffer_size)
412
+ assert ret is not None
413
+ if ret.status == self.GetSetReturn.FAILURE:
414
+ self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
415
+ elif ret.status == self.GetSetReturn.SUCCESS:
416
+ data = bytearray()
417
+ data.append(report_id)
418
+ data.extend(ret.data)
419
+ if len(data) < self.l2cap_ctrl_channel.mtu: # type: ignore[union-attr]
420
+ self.send_control_data(report_type=report_type, data=data)
421
+ else:
422
+ self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
423
+ elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
424
+ self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
425
+ elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
426
+ self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
427
+ elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
428
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
429
+
430
+ def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
431
+ self.get_report_cb = cb
432
+ logger.debug("GetReport callback registered successfully")
433
+
434
+ def handle_set_report(self, pdu: bytes):
435
+ if self.set_report_cb is None:
436
+ logger.debug("SetReport callback not registered !!")
437
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
438
+ return
439
+ report_type = pdu[0] & 0x03
440
+ report_id = pdu[1]
441
+ report_data = pdu[2:]
442
+ report_size = len(report_data) + 1
443
+ ret = self.set_report_cb(report_id, report_type, report_size, report_data)
444
+ assert ret is not None
445
+ if ret.status == self.GetSetReturn.SUCCESS:
446
+ self.send_handshake_message(Message.Handshake.SUCCESSFUL)
447
+ elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
448
+ self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
449
+ elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
450
+ self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
451
+ else:
452
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
453
+
454
+ def register_set_report_cb(
455
+ self, cb: Callable[[int, int, int, bytes], None]
456
+ ) -> None:
457
+ self.set_report_cb = cb
458
+ logger.debug("SetReport callback registered successfully")
459
+
460
+ def handle_get_protocol(self, pdu: bytes):
461
+ if self.get_protocol_cb is None:
462
+ logger.debug("GetProtocol callback not registered !!")
463
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
464
+ return
465
+ ret = self.get_protocol_cb()
466
+ assert ret is not None
467
+ if ret.status == self.GetSetReturn.SUCCESS:
468
+ self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
469
+ else:
470
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
471
+
472
+ def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
473
+ self.get_protocol_cb = cb
474
+ logger.debug("GetProtocol callback registered successfully")
475
+
476
+ def handle_set_protocol(self, pdu: bytes):
477
+ if self.set_protocol_cb is None:
478
+ logger.debug("SetProtocol callback not registered !!")
479
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
480
+ return
481
+ ret = self.set_protocol_cb(pdu[0] & 0x01)
482
+ assert ret is not None
483
+ if ret.status == self.GetSetReturn.SUCCESS:
484
+ self.send_handshake_message(Message.Handshake.SUCCESSFUL)
485
+ else:
486
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
487
+
488
+ def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
489
+ self.set_protocol_cb = cb
490
+ logger.debug("SetProtocol callback registered successfully")
491
+
492
+
493
+ # -----------------------------------------------------------------------------
494
+ class Host(HID):
495
+ def __init__(self, device: device.Device) -> None:
496
+ super().__init__(device, HID.Role.HOST)
276
497
 
277
498
  def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
278
499
  msg = GetReportMessage(
@@ -282,52 +503,52 @@ class Host(EventEmitter):
282
503
  logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
283
504
  self.send_pdu_on_ctrl(hid_message)
284
505
 
285
- def set_report(self, report_type: int, data: bytes):
506
+ def set_report(self, report_type: int, data: bytes) -> None:
286
507
  msg = SetReportMessage(report_type=report_type, data=data)
287
508
  hid_message = bytes(msg)
288
509
  logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
289
510
  self.send_pdu_on_ctrl(hid_message)
290
511
 
291
- def get_protocol(self):
512
+ def get_protocol(self) -> None:
292
513
  msg = GetProtocolMessage()
293
514
  hid_message = bytes(msg)
294
515
  logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
295
516
  self.send_pdu_on_ctrl(hid_message)
296
517
 
297
- def set_protocol(self, protocol_mode: int):
518
+ def set_protocol(self, protocol_mode: int) -> None:
298
519
  msg = SetProtocolMessage(protocol_mode=protocol_mode)
299
520
  hid_message = bytes(msg)
300
521
  logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
301
522
  self.send_pdu_on_ctrl(hid_message)
302
523
 
303
- def send_pdu_on_ctrl(self, msg: bytes) -> None:
304
- assert self.l2cap_ctrl_channel
305
- self.l2cap_ctrl_channel.send_pdu(msg)
306
-
307
- def send_pdu_on_intr(self, msg: bytes) -> None:
308
- assert self.l2cap_intr_channel
309
- self.l2cap_intr_channel.send_pdu(msg)
310
-
311
- def send_data(self, data):
312
- msg = SendData(data)
313
- hid_message = bytes(msg)
314
- logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
315
- self.send_pdu_on_intr(hid_message)
316
-
317
- def suspend(self):
524
+ def suspend(self) -> None:
318
525
  msg = Suspend()
319
526
  hid_message = bytes(msg)
320
527
  logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
321
- self.send_pdu_on_ctrl(msg)
528
+ self.send_pdu_on_ctrl(hid_message)
322
529
 
323
- def exit_suspend(self):
530
+ def exit_suspend(self) -> None:
324
531
  msg = ExitSuspend()
325
532
  hid_message = bytes(msg)
326
533
  logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
327
- self.send_pdu_on_ctrl(msg)
534
+ self.send_pdu_on_ctrl(hid_message)
328
535
 
329
- def virtual_cable_unplug(self):
330
- msg = VirtualCableUnplug()
331
- hid_message = bytes(msg)
332
- logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
333
- self.send_pdu_on_ctrl(msg)
536
+ @override
537
+ def on_ctrl_pdu(self, pdu: bytes) -> None:
538
+ logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
539
+ param = pdu[0] & 0x0F
540
+ message_type = pdu[0] >> 4
541
+ if message_type == Message.MessageType.HANDSHAKE:
542
+ logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
543
+ self.emit('handshake', Message.Handshake(param))
544
+ elif message_type == Message.MessageType.DATA:
545
+ logger.debug('<<< HID CONTROL DATA')
546
+ self.emit('control_data', pdu)
547
+ elif message_type == Message.MessageType.CONTROL:
548
+ if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
549
+ logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
550
+ self.emit('virtual_cable_unplug')
551
+ else:
552
+ logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
553
+ else:
554
+ logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')