bumble 0.0.179__py3-none-any.whl → 0.0.181__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bumble/hid.py CHANGED
@@ -19,16 +19,17 @@ from __future__ import annotations
19
19
  from dataclasses import dataclass
20
20
  import logging
21
21
  import enum
22
+ import struct
22
23
 
24
+ from abc import ABC, abstractmethod
23
25
  from pyee import EventEmitter
24
- from typing import Optional, TYPE_CHECKING
26
+ from typing import Optional, Callable, TYPE_CHECKING
27
+ from typing_extensions import override
25
28
 
26
- from bumble import l2cap
29
+ from bumble import l2cap, device
27
30
  from bumble.colors import color
28
31
  from bumble.core import InvalidStateError, ProtocolError
29
-
30
- if TYPE_CHECKING:
31
- from bumble.device import Device, Connection
32
+ from .hci import Address
32
33
 
33
34
 
34
35
  # -----------------------------------------------------------------------------
@@ -60,6 +61,7 @@ class Message:
60
61
  NOT_READY = 0x01
61
62
  ERR_INVALID_REPORT_ID = 0x02
62
63
  ERR_UNSUPPORTED_REQUEST = 0x03
64
+ ERR_INVALID_PARAMETER = 0x04
63
65
  ERR_UNKNOWN = 0x0E
64
66
  ERR_FATAL = 0x0F
65
67
 
@@ -101,13 +103,14 @@ class GetReportMessage(Message):
101
103
  def __bytes__(self) -> bytes:
102
104
  packet_bytes = bytearray()
103
105
  packet_bytes.append(self.report_id)
104
- packet_bytes.extend(
105
- [(self.buffer_size & 0xFF), ((self.buffer_size >> 8) & 0xFF)]
106
- )
107
- if self.report_type == Message.ReportType.OTHER_REPORT:
106
+ if self.buffer_size == 0:
108
107
  return self.header(self.report_type) + packet_bytes
109
108
  else:
110
- return self.header(0x08 | self.report_type) + packet_bytes
109
+ return (
110
+ self.header(0x08 | self.report_type)
111
+ + packet_bytes
112
+ + struct.pack("<H", self.buffer_size)
113
+ )
111
114
 
112
115
 
113
116
  @dataclass
@@ -120,6 +123,16 @@ class SetReportMessage(Message):
120
123
  return self.header(self.report_type) + self.data
121
124
 
122
125
 
126
+ @dataclass
127
+ class SendControlData(Message):
128
+ report_type: int
129
+ data: bytes
130
+ message_type = Message.MessageType.DATA
131
+
132
+ def __bytes__(self) -> bytes:
133
+ return self.header(self.report_type) + self.data
134
+
135
+
123
136
  @dataclass
124
137
  class GetProtocolMessage(Message):
125
138
  message_type = Message.MessageType.GET_PROTOCOL
@@ -161,31 +174,47 @@ class VirtualCableUnplug(Message):
161
174
  return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)
162
175
 
163
176
 
177
+ # Device sends input report, host sends output report.
164
178
  @dataclass
165
179
  class SendData(Message):
166
180
  data: bytes
181
+ report_type: int
167
182
  message_type = Message.MessageType.DATA
168
183
 
169
184
  def __bytes__(self) -> bytes:
170
- return self.header(Message.ReportType.OUTPUT_REPORT) + self.data
185
+ return self.header(self.report_type) + self.data
186
+
187
+
188
+ @dataclass
189
+ class SendHandshakeMessage(Message):
190
+ result_code: int
191
+ message_type = Message.MessageType.HANDSHAKE
192
+
193
+ def __bytes__(self) -> bytes:
194
+ return self.header(self.result_code)
171
195
 
172
196
 
173
197
  # -----------------------------------------------------------------------------
174
- class Host(EventEmitter):
175
- l2cap_ctrl_channel: Optional[l2cap.ClassicChannel]
176
- l2cap_intr_channel: Optional[l2cap.ClassicChannel]
198
+ class HID(ABC, EventEmitter):
199
+ l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
200
+ l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
201
+ connection: Optional[device.Connection] = None
202
+
203
+ class Role(enum.IntEnum):
204
+ HOST = 0x00
205
+ DEVICE = 0x01
177
206
 
178
- def __init__(self, device: Device, connection: Connection) -> None:
207
+ def __init__(self, device: device.Device, role: Role) -> None:
179
208
  super().__init__()
209
+ self.remote_device_bd_address: Optional[Address] = None
180
210
  self.device = device
181
- self.connection = connection
182
-
183
- self.l2cap_ctrl_channel = None
184
- self.l2cap_intr_channel = None
211
+ self.role = role
185
212
 
186
213
  # Register ourselves with the L2CAP channel manager
187
- device.register_l2cap_server(HID_CONTROL_PSM, self.on_connection)
188
- device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_connection)
214
+ device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
215
+ device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
216
+
217
+ device.on('connection', self.on_device_connection)
189
218
 
190
219
  async def connect_control_channel(self) -> None:
191
220
  # Create a new L2CAP connection - control channel
@@ -229,9 +258,18 @@ class Host(EventEmitter):
229
258
  self.l2cap_ctrl_channel = None
230
259
  await channel.disconnect()
231
260
 
232
- def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
261
+ def on_device_connection(self, connection: device.Connection) -> None:
262
+ self.connection = connection
263
+ self.remote_device_bd_address = connection.peer_address
264
+ connection.on('disconnection', self.on_device_disconnection)
265
+
266
+ def on_device_disconnection(self, reason: int) -> None:
267
+ self.connection = None
268
+
269
+ def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
233
270
  logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
234
271
  l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
272
+ l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
235
273
 
236
274
  def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
237
275
  if l2cap_channel.psm == HID_CONTROL_PSM:
@@ -242,37 +280,220 @@ class Host(EventEmitter):
242
280
  self.l2cap_intr_channel.sink = self.on_intr_pdu
243
281
  logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
244
282
 
283
+ def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
284
+ if l2cap_channel.psm == HID_CONTROL_PSM:
285
+ self.l2cap_ctrl_channel = None
286
+ else:
287
+ self.l2cap_intr_channel = None
288
+ logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')
289
+
290
+ @abstractmethod
291
+ def on_ctrl_pdu(self, pdu: bytes) -> None:
292
+ pass
293
+
294
+ def on_intr_pdu(self, pdu: bytes) -> None:
295
+ logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
296
+ self.emit("interrupt_data", pdu)
297
+
298
+ def send_pdu_on_ctrl(self, msg: bytes) -> None:
299
+ assert self.l2cap_ctrl_channel
300
+ self.l2cap_ctrl_channel.send_pdu(msg)
301
+
302
+ def send_pdu_on_intr(self, msg: bytes) -> None:
303
+ assert self.l2cap_intr_channel
304
+ self.l2cap_intr_channel.send_pdu(msg)
305
+
306
+ def send_data(self, data: bytes) -> None:
307
+ if self.role == HID.Role.HOST:
308
+ report_type = Message.ReportType.OUTPUT_REPORT
309
+ else:
310
+ report_type = Message.ReportType.INPUT_REPORT
311
+ msg = SendData(data, report_type)
312
+ hid_message = bytes(msg)
313
+ if self.l2cap_intr_channel is not None:
314
+ logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
315
+ self.send_pdu_on_intr(hid_message)
316
+
317
+ def virtual_cable_unplug(self) -> None:
318
+ msg = VirtualCableUnplug()
319
+ hid_message = bytes(msg)
320
+ logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
321
+ self.send_pdu_on_ctrl(hid_message)
322
+
323
+
324
+ # -----------------------------------------------------------------------------
325
+
326
+
327
+ class Device(HID):
328
+ class GetSetReturn(enum.IntEnum):
329
+ FAILURE = 0x00
330
+ REPORT_ID_NOT_FOUND = 0x01
331
+ ERR_UNSUPPORTED_REQUEST = 0x02
332
+ ERR_UNKNOWN = 0x03
333
+ ERR_INVALID_PARAMETER = 0x04
334
+ SUCCESS = 0xFF
335
+
336
+ class GetSetStatus:
337
+ def __init__(self) -> None:
338
+ self.data = bytearray()
339
+ self.status = 0
340
+
341
+ def __init__(self, device: device.Device) -> None:
342
+ super().__init__(device, HID.Role.DEVICE)
343
+ get_report_cb: Optional[Callable[[int, int, int], None]] = None
344
+ set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
345
+ get_protocol_cb: Optional[Callable[[], None]] = None
346
+ set_protocol_cb: Optional[Callable[[int], None]] = None
347
+
348
+ @override
245
349
  def on_ctrl_pdu(self, pdu: bytes) -> None:
246
350
  logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
247
- # Here we will receive all kinds of packets, parse and then call respective callbacks
248
- message_type = pdu[0] >> 4
249
351
  param = pdu[0] & 0x0F
352
+ message_type = pdu[0] >> 4
250
353
 
251
- if message_type == Message.MessageType.HANDSHAKE:
252
- logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
253
- self.emit('handshake', Message.Handshake(param))
354
+ if message_type == Message.MessageType.GET_REPORT:
355
+ logger.debug('<<< HID GET REPORT')
356
+ self.handle_get_report(pdu)
357
+ elif message_type == Message.MessageType.SET_REPORT:
358
+ logger.debug('<<< HID SET REPORT')
359
+ self.handle_set_report(pdu)
360
+ elif message_type == Message.MessageType.GET_PROTOCOL:
361
+ logger.debug('<<< HID GET PROTOCOL')
362
+ self.handle_get_protocol(pdu)
363
+ elif message_type == Message.MessageType.SET_PROTOCOL:
364
+ logger.debug('<<< HID SET PROTOCOL')
365
+ self.handle_set_protocol(pdu)
254
366
  elif message_type == Message.MessageType.DATA:
255
367
  logger.debug('<<< HID CONTROL DATA')
256
- self.emit('data', pdu)
368
+ self.emit('control_data', pdu)
257
369
  elif message_type == Message.MessageType.CONTROL:
258
370
  if param == Message.ControlCommand.SUSPEND:
259
371
  logger.debug('<<< HID SUSPEND')
260
- self.emit('suspend', pdu)
372
+ self.emit('suspend')
261
373
  elif param == Message.ControlCommand.EXIT_SUSPEND:
262
374
  logger.debug('<<< HID EXIT SUSPEND')
263
- self.emit('exit_suspend', pdu)
375
+ self.emit('exit_suspend')
264
376
  elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
265
377
  logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
266
378
  self.emit('virtual_cable_unplug')
267
379
  else:
268
380
  logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
269
381
  else:
270
- logger.debug('<<< HID CONTROL DATA')
271
- self.emit('data', pdu)
382
+ logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
383
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
272
384
 
273
- def on_intr_pdu(self, pdu: bytes) -> None:
274
- logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
275
- self.emit("data", pdu)
385
+ def send_handshake_message(self, result_code: int) -> None:
386
+ msg = SendHandshakeMessage(result_code)
387
+ hid_message = bytes(msg)
388
+ logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
389
+ self.send_pdu_on_ctrl(hid_message)
390
+
391
+ def send_control_data(self, report_type: int, data: bytes):
392
+ msg = SendControlData(report_type=report_type, data=data)
393
+ hid_message = bytes(msg)
394
+ logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
395
+ self.send_pdu_on_ctrl(hid_message)
396
+
397
+ def handle_get_report(self, pdu: bytes):
398
+ if self.get_report_cb is None:
399
+ logger.debug("GetReport callback not registered !!")
400
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
401
+ return
402
+ report_type = pdu[0] & 0x03
403
+ buffer_flag = (pdu[0] & 0x08) >> 3
404
+ report_id = pdu[1]
405
+ logger.debug(f"buffer_flag: {buffer_flag}")
406
+ if buffer_flag == 1:
407
+ buffer_size = (pdu[3] << 8) | pdu[2]
408
+ else:
409
+ buffer_size = 0
410
+
411
+ ret = self.get_report_cb(report_id, report_type, buffer_size)
412
+ assert ret is not None
413
+ if ret.status == self.GetSetReturn.FAILURE:
414
+ self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
415
+ elif ret.status == self.GetSetReturn.SUCCESS:
416
+ data = bytearray()
417
+ data.append(report_id)
418
+ data.extend(ret.data)
419
+ if len(data) < self.l2cap_ctrl_channel.mtu: # type: ignore[union-attr]
420
+ self.send_control_data(report_type=report_type, data=data)
421
+ else:
422
+ self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
423
+ elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
424
+ self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
425
+ elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
426
+ self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
427
+ elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
428
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
429
+
430
+ def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
431
+ self.get_report_cb = cb
432
+ logger.debug("GetReport callback registered successfully")
433
+
434
+ def handle_set_report(self, pdu: bytes):
435
+ if self.set_report_cb is None:
436
+ logger.debug("SetReport callback not registered !!")
437
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
438
+ return
439
+ report_type = pdu[0] & 0x03
440
+ report_id = pdu[1]
441
+ report_data = pdu[2:]
442
+ report_size = len(report_data) + 1
443
+ ret = self.set_report_cb(report_id, report_type, report_size, report_data)
444
+ assert ret is not None
445
+ if ret.status == self.GetSetReturn.SUCCESS:
446
+ self.send_handshake_message(Message.Handshake.SUCCESSFUL)
447
+ elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
448
+ self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
449
+ elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
450
+ self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
451
+ else:
452
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
453
+
454
+ def register_set_report_cb(
455
+ self, cb: Callable[[int, int, int, bytes], None]
456
+ ) -> None:
457
+ self.set_report_cb = cb
458
+ logger.debug("SetReport callback registered successfully")
459
+
460
+ def handle_get_protocol(self, pdu: bytes):
461
+ if self.get_protocol_cb is None:
462
+ logger.debug("GetProtocol callback not registered !!")
463
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
464
+ return
465
+ ret = self.get_protocol_cb()
466
+ assert ret is not None
467
+ if ret.status == self.GetSetReturn.SUCCESS:
468
+ self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
469
+ else:
470
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
471
+
472
+ def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
473
+ self.get_protocol_cb = cb
474
+ logger.debug("GetProtocol callback registered successfully")
475
+
476
+ def handle_set_protocol(self, pdu: bytes):
477
+ if self.set_protocol_cb is None:
478
+ logger.debug("SetProtocol callback not registered !!")
479
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
480
+ return
481
+ ret = self.set_protocol_cb(pdu[0] & 0x01)
482
+ assert ret is not None
483
+ if ret.status == self.GetSetReturn.SUCCESS:
484
+ self.send_handshake_message(Message.Handshake.SUCCESSFUL)
485
+ else:
486
+ self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
487
+
488
+ def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
489
+ self.set_protocol_cb = cb
490
+ logger.debug("SetProtocol callback registered successfully")
491
+
492
+
493
+ # -----------------------------------------------------------------------------
494
+ class Host(HID):
495
+ def __init__(self, device: device.Device) -> None:
496
+ super().__init__(device, HID.Role.HOST)
276
497
 
277
498
  def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
278
499
  msg = GetReportMessage(
@@ -282,52 +503,52 @@ class Host(EventEmitter):
282
503
  logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
283
504
  self.send_pdu_on_ctrl(hid_message)
284
505
 
285
- def set_report(self, report_type: int, data: bytes):
506
+ def set_report(self, report_type: int, data: bytes) -> None:
286
507
  msg = SetReportMessage(report_type=report_type, data=data)
287
508
  hid_message = bytes(msg)
288
509
  logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
289
510
  self.send_pdu_on_ctrl(hid_message)
290
511
 
291
- def get_protocol(self):
512
+ def get_protocol(self) -> None:
292
513
  msg = GetProtocolMessage()
293
514
  hid_message = bytes(msg)
294
515
  logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
295
516
  self.send_pdu_on_ctrl(hid_message)
296
517
 
297
- def set_protocol(self, protocol_mode: int):
518
+ def set_protocol(self, protocol_mode: int) -> None:
298
519
  msg = SetProtocolMessage(protocol_mode=protocol_mode)
299
520
  hid_message = bytes(msg)
300
521
  logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
301
522
  self.send_pdu_on_ctrl(hid_message)
302
523
 
303
- def send_pdu_on_ctrl(self, msg: bytes) -> None:
304
- assert self.l2cap_ctrl_channel
305
- self.l2cap_ctrl_channel.send_pdu(msg)
306
-
307
- def send_pdu_on_intr(self, msg: bytes) -> None:
308
- assert self.l2cap_intr_channel
309
- self.l2cap_intr_channel.send_pdu(msg)
310
-
311
- def send_data(self, data):
312
- msg = SendData(data)
313
- hid_message = bytes(msg)
314
- logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
315
- self.send_pdu_on_intr(hid_message)
316
-
317
- def suspend(self):
524
+ def suspend(self) -> None:
318
525
  msg = Suspend()
319
526
  hid_message = bytes(msg)
320
527
  logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
321
- self.send_pdu_on_ctrl(msg)
528
+ self.send_pdu_on_ctrl(hid_message)
322
529
 
323
- def exit_suspend(self):
530
+ def exit_suspend(self) -> None:
324
531
  msg = ExitSuspend()
325
532
  hid_message = bytes(msg)
326
533
  logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
327
- self.send_pdu_on_ctrl(msg)
534
+ self.send_pdu_on_ctrl(hid_message)
328
535
 
329
- def virtual_cable_unplug(self):
330
- msg = VirtualCableUnplug()
331
- hid_message = bytes(msg)
332
- logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
333
- self.send_pdu_on_ctrl(msg)
536
+ @override
537
+ def on_ctrl_pdu(self, pdu: bytes) -> None:
538
+ logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
539
+ param = pdu[0] & 0x0F
540
+ message_type = pdu[0] >> 4
541
+ if message_type == Message.MessageType.HANDSHAKE:
542
+ logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
543
+ self.emit('handshake', Message.Handshake(param))
544
+ elif message_type == Message.MessageType.DATA:
545
+ logger.debug('<<< HID CONTROL DATA')
546
+ self.emit('control_data', pdu)
547
+ elif message_type == Message.MessageType.CONTROL:
548
+ if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
549
+ logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
550
+ self.emit('virtual_cable_unplug')
551
+ else:
552
+ logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
553
+ else:
554
+ logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')