swbt-python 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
swbt/report_loop.py ADDED
@@ -0,0 +1,115 @@
1
+ """Input report sender used by SwitchGamepad."""
2
+
3
+ import asyncio
4
+ from collections import deque
5
+ from contextlib import suppress
6
+
7
+ from swbt.diagnostics import DiagnosticsRecorder
8
+ from swbt.protocol.input_report import InputReportBuilder
9
+ from swbt.state_store import InputStateStore
10
+ from swbt.transport.base import HidDeviceTransport
11
+
12
# After a subcommand reply is sent, periodic 0x30 reports are skipped for this many seconds.
REPLY_PERIODIC_HOLDOFF_SECONDS: float = 0.3
13
+
14
+
15
class ReportLoop:
    """Send input reports from the current input state.

    A background task (see :meth:`start`) wakes every ``report_period_us``
    microseconds and sends either the oldest queued subcommand reply or a
    ``0x30`` input report built from ``state_store``. For
    ``REPLY_PERIODIC_HOLDOFF_SECONDS`` after any reply is sent, periodic
    ``0x30`` reports are skipped, but queued replies are still sent. All sends
    are serialized by one ``asyncio.Lock``. A single 8-bit counter is passed as
    the ``timer`` of ``0x30`` reports and stamped into byte 1 of ``0x21``
    replies.
    """

    # Report ID of a subcommand reply; byte 1 of such a reply receives the timer.
    _SUBCOMMAND_REPLY_ID = 0x21

    def __init__(
        self,
        *,
        transport: HidDeviceTransport,
        state_store: InputStateStore,
        report_period_us: int = 8000,
        input_report_builder: InputReportBuilder | None = None,
        diagnostics: DiagnosticsRecorder | None = None,
    ) -> None:
        """Create a report loop helper.

        Args:
            transport: Transport whose ``send_interrupt`` carries every report.
            state_store: Source of the state encoded into ``0x30`` reports.
            report_period_us: Sleep between periodic sends, in microseconds.
            input_report_builder: Builder for ``0x30`` reports; a default
                ``InputReportBuilder()`` is created when omitted.
            diagnostics: Optional recorder notified of every sent report.
        """
        self._transport = transport
        self._state_store = state_store
        self._report_period_seconds = report_period_us / 1_000_000
        self._input_report_builder = input_report_builder or InputReportBuilder()
        self._diagnostics = diagnostics
        # Replies waiting for the next periodic slot, oldest first.
        self._reply_queue: deque[bytes] = deque()
        # Shared 8-bit report timer; wraps from 0xFF back to 0.
        self._timer = 0
        # Event-loop time until which periodic 0x30 reports are suppressed.
        self._periodic_holdoff_until = 0.0
        self._send_lock = asyncio.Lock()
        self._task: asyncio.Task[None] | None = None

    def start(self) -> None:
        """Start periodic input report transmission.

        Does nothing while a previously started task is still running. Must be
        called from a running event loop (it uses ``asyncio.create_task``).
        """
        if self._task is not None and not self._task.done():
            return
        self._task = asyncio.create_task(self._run(), name="swbt-report-loop")

    async def stop(self) -> None:
        """Stop periodic input report transmission.

        Cancels the task and waits for it. Cancellation is swallowed. If the
        task had already failed with another exception, that exception is
        re-raised here.
        """
        if self._task is None:
            return
        task = self._task
        self._task = None
        task.cancel()
        with suppress(asyncio.CancelledError):
            await task

    async def send_current_input(self, *, reason: str = "input") -> None:
        """Send one 0x30 input report for the current state."""
        async with self._send_lock:
            await self._send_current_input_locked(reason=reason)

    async def send_subcommand_reply(self, report: bytes) -> None:
        """Send one 0x21 subcommand reply with the shared report timer.

        Periodic 0x30 reports are held off afterwards.
        """
        async with self._send_lock:
            await self._send_subcommand_reply_locked(report)
        self._holdoff_periodic_after_reply()

    def queue_reply(self, report: bytes) -> None:
        """Queue one subcommand reply for priority transmission.

        The reply is copied and sent by the next :meth:`send_next_report`.
        """
        self._reply_queue.append(bytes(report))

    async def send_next_report(self) -> None:
        """Send the next queued reply or current input report.

        A queued reply always wins. Otherwise a periodic 0x30 report is sent,
        unless the post-reply holdoff is still active.
        """
        async with self._send_lock:
            if self._reply_queue:
                await self._send_subcommand_reply_locked(self._reply_queue.popleft())
                self._holdoff_periodic_after_reply()
                return
            if self._is_periodic_held_off():
                return
            await self._send_current_input_locked(reason="periodic")

    async def _send_current_input_locked(self, *, reason: str) -> None:
        # Caller must hold ``_send_lock``.
        state = await self._state_store.snapshot()
        report = self._input_report_builder.build_0x30(state, timer=self._timer)
        await self._send_report(report, reason=reason)
        self._advance_timer()

    async def _send_subcommand_reply_locked(self, report: bytes) -> None:
        # Caller must hold ``_send_lock``.
        reply = bytearray(report)
        # Stamp only 0x21 replies that actually have a byte 1. A truncated reply
        # is forwarded unchanged rather than raising IndexError, which would
        # otherwise kill the periodic task when the reply came from the queue.
        stamp_timer = len(reply) > 1 and reply[0] == self._SUBCOMMAND_REPLY_ID
        if stamp_timer:
            reply[1] = self._timer
        await self._send_report(bytes(reply), reason="subcommand_reply")
        if stamp_timer:
            self._advance_timer()

    async def _run(self) -> None:
        # Sleep first, then send. The effective period is the sleep plus the
        # time spent sending.
        while True:
            await asyncio.sleep(self._report_period_seconds)
            await self.send_next_report()

    async def _send_report(self, report: bytes, *, reason: str) -> None:
        await self._transport.send_interrupt(report)
        report_id = report[0]
        if self._diagnostics is not None:
            self._diagnostics.record_report_tx(report_id=report_id, reason=reason)

    def _advance_timer(self) -> None:
        self._timer = (self._timer + 1) & 0xFF

    def _holdoff_periodic_after_reply(self) -> None:
        self._periodic_holdoff_until = (
            asyncio.get_running_loop().time() + REPLY_PERIODIC_HOLDOFF_SECONDS
        )

    def _is_periodic_held_off(self) -> bool:
        return asyncio.get_running_loop().time() < self._periodic_holdoff_until
swbt/state_store.py ADDED
@@ -0,0 +1,60 @@
1
+ """Async-safe input state storage."""
2
+
3
+ import asyncio
4
+
5
+ from swbt.input import Button, IMUFrame, InputState, Stick
6
+
7
+
8
class InputStateStore:
    """Store the current immutable input state behind an async lock.

    Every mutator replaces the stored ``InputState`` while holding the lock and
    returns the newly stored state.
    """

    def __init__(self, initial_state: InputState | None = None) -> None:
        """Create a state store.

        Args:
            initial_state: Starting state; ``InputState.neutral()`` when omitted.
        """
        # Use ``is None`` rather than ``or`` so a caller-supplied state is never
        # swapped for neutral input merely because it evaluates as falsy.
        self._state = InputState.neutral() if initial_state is None else initial_state
        self._lock = asyncio.Lock()

    async def snapshot(self) -> InputState:
        """Return the current input state, read under the lock."""
        async with self._lock:
            return self._state

    @property
    def current(self) -> InputState:
        """Return the latest committed input state without taking the lock."""
        return self._state

    async def apply(self, state: InputState) -> InputState:
        """Replace the current input state with ``state`` and return it."""
        async with self._lock:
            self._state = state
            return self._state

    async def sticks(self, *, left: Stick | None = None, right: Stick | None = None) -> InputState:
        """Replace one or both stick positions.

        ``None`` is forwarded to ``InputState.with_sticks`` for a stick not
        being changed.
        """
        async with self._lock:
            self._state = self._state.with_sticks(left_stick=left, right_stick=right)
            return self._state

    async def imu(self, *frames: IMUFrame) -> InputState:
        """Replace IMU frames with ``frames``."""
        async with self._lock:
            self._state = self._state.with_imu(*frames)
            return self._state

    async def press(self, *buttons: Button) -> InputState:
        """Add ``buttons`` to the currently pressed buttons."""
        async with self._lock:
            self._state = self._state.with_buttons((*self._state.buttons, *buttons))
            return self._state

    async def release(self, *buttons: Button) -> InputState:
        """Remove ``buttons`` from the currently pressed buttons."""
        async with self._lock:
            self._state = self._state.with_buttons(self._state.buttons.difference(buttons))
            return self._state

    async def neutral(self) -> InputState:
        """Replace the current input state with neutral input."""
        async with self._lock:
            self._state = InputState.neutral()
            return self._state
@@ -0,0 +1,3 @@
1
+ """HID transport implementations."""
2
+
3
# Explicitly empty: a star-import of this package exports no names.
__all__: tuple[str, ...] = tuple()
@@ -0,0 +1,33 @@
1
+ """ACL queue drain helper for Bumble HID transport."""
2
+
3
+ from collections.abc import Awaitable
4
+
5
+
6
async def drain_bumble_acl_queue(l2cap_channel: object) -> None:
    """Wait until Bumble reports no pending ACL packets for the channel.

    The packet queue is taken from ``connection.acl_packet_queue``. When that
    attribute is missing, it is looked up via
    ``connection.device.host.get_data_packet_queue(handle)``. The call is a
    no-op when no integer connection handle or no callable ``drain`` is found.

    Draining repeats while ``pending`` is a positive int that changed since the
    previous pass. It stops after a non-awaitable ``drain`` result. A
    ``ValueError`` raised while draining is swallowed.
    """
    conn = getattr(l2cap_channel, "connection", None)
    handle = getattr(conn, "handle", None)
    queue = getattr(conn, "acl_packet_queue", None)
    has_int_handle = isinstance(handle, int)
    if queue is None and has_int_handle:
        host = getattr(getattr(conn, "device", None), "host", None)
        queue_lookup = getattr(host, "get_data_packet_queue", None)
        if callable(queue_lookup):
            queue = queue_lookup(handle)
    drain_fn = getattr(queue, "drain", None)
    if not (has_int_handle and callable(drain_fn)):
        return
    previous: int | None = None
    try:
        while True:
            outstanding = getattr(queue, "pending", 0)
            # Stop when nothing is pending, the count is unusable, or the last
            # drain pass made no progress (guards against spinning forever).
            if not isinstance(outstanding, int) or outstanding <= 0 or outstanding == previous:
                break
            previous = outstanding
            result = drain_fn(handle)
            if not isinstance(result, Awaitable):
                break
            await result
    except ValueError:
        return
@@ -0,0 +1,34 @@
1
+ """HIDP helpers for Bumble HID transport."""
2
+
3
# HIDP header byte: message type in the high nibble, report type in the low two bits.
HIDP_DATA_MESSAGE_TYPE: int = 0x0A
HID_OUTPUT_REPORT_TYPE: int = 0x02

# GET/SET result codes.
HID_GET_SET_SUCCESS: int = 0xFF
HID_GET_SET_UNSUPPORTED_REQUEST: int = 0x02

# L2CAP PSMs of the HID control and interrupt channels.
HID_CONTROL_PSM: int = 0x0011
HID_INTERRUPT_PSM: int = 0x0013
9
+
10
+
11
def decode_hidp_output_report(pdu: bytes) -> bytes | None:
    """Return the payload of a HIDP DATA PDU carrying an OUTPUT report.

    The first byte is the HIDP header: its high nibble must equal
    ``HIDP_DATA_MESSAGE_TYPE`` and its two lowest bits must equal
    ``HID_OUTPUT_REPORT_TYPE``. Empty input or any other header yields ``None``.
    """
    if not pdu:
        return None
    header = pdu[0]
    is_data_message = (header >> 4) == HIDP_DATA_MESSAGE_TYPE
    is_output_report = (header & 0x03) == HID_OUTPUT_REPORT_TYPE
    return pdu[1:] if is_data_message and is_output_report else None
21
+
22
+
23
def hid_channel_name(psm: object) -> str:
    """Map a PSM to ``"control"``, ``"interrupt"``, or ``"unknown"``."""
    if psm == HID_CONTROL_PSM:
        label = "control"
    elif psm == HID_INTERRUPT_PSM:
        label = "interrupt"
    else:
        label = "unknown"
    return label
29
+
30
+
31
def format_psm(psm: object) -> str:
    """Render an int PSM as ``0x`` plus at least four lowercase hex digits; else ``"unknown"``."""
    return f"0x{psm:04x}" if isinstance(psm, int) else "unknown"
@@ -0,0 +1,220 @@
1
+ """Bumble JSON key store metadata helpers."""
2
+
3
+ import copy
4
+ import json
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any, Protocol, cast
8
+
9
+ from swbt.diagnostics import DiagnosticsRecorder
10
+ from swbt.errors import InvalidKeyStoreError
11
+
12
# Prefix of the top-level JSON namespace that stores the previous key generation.
PREVIOUS_NAMESPACE_PREFIX: str = "swbt.previous::"
13
+
14
+
15
@dataclass(eq=True, frozen=True)
class KeyStoreMetadata:
    """Observed, non-sensitive facts about a Bumble JSON key store file.

    Instances are immutable.
    """

    # Whether anything exists at the key store path.
    exists: bool
    # Whether a previous-generation namespace was found in the file.
    previous_exists: bool
+ previous_exists: bool
21
+
22
+
23
def read_key_store_metadata(key_store_path: str | Path) -> KeyStoreMetadata:
    """Read non-sensitive metadata from a Bumble JSON key store file.

    Args:
        key_store_path: Location of the key store file.

    Returns:
        ``exists=False, previous_exists=False`` when nothing exists at the path.
        Otherwise ``exists=True``, with ``previous_exists`` reporting whether
        the file contains a previous-generation namespace.
    """
    store_path = Path(key_store_path)
    present = store_path.exists()
    return KeyStoreMetadata(
        exists=present,
        previous_exists=present and _previous_generation_exists(store_path),
    )
32
+
33
+
34
def _previous_generation_exists(key_store_path: Path) -> bool:
    """Return whether the key store JSON has any previous-generation namespace.

    Unreadable files, undecodable text or JSON (both ``ValueError``), and JSON
    that is not an object all count as "no previous generation".
    """
    try:
        raw_text = key_store_path.read_text(encoding="utf-8")
        parsed = json.loads(raw_text)
    except (OSError, ValueError):
        return False
    if isinstance(parsed, dict):
        return any(str(key).startswith(PREVIOUS_NAMESPACE_PREFIX) for key in parsed)
    return False
42
+
43
+
44
+ class _BumbleJsonKeyStoreRuntime(Protocol):
45
+ namespace: str
46
+
47
+ async def load(
48
+ self,
49
+ ) -> tuple[dict[str, dict[str, dict[str, object]]], dict[str, dict[str, object]]]:
50
+ """Load the full JSON DB and this store's current key map."""
51
+
52
+ async def save(self, db: dict[str, dict[str, dict[str, object]]]) -> None:
53
+ """Save the full JSON DB."""
54
+
55
+ async def get(self, name: str) -> object | None:
56
+ """Return one key entry."""
57
+
58
+ async def get_all(self) -> list[tuple[str, object]]:
59
+ """Return all current key entries."""
60
+
61
+ async def delete(self, name: str) -> None:
62
+ """Delete one current key entry."""
63
+
64
+ async def delete_all(self) -> None:
65
+ """Delete all current key entries."""
66
+
67
+ async def get_resolving_keys(self) -> object:
68
+ """Return LE resolving keys for Bumble internals."""
69
+
70
+
71
class _CurrentPreviousJsonKeyStore:
    """Bumble-compatible JSON key store with one previous generation.

    The current generation lives in a Bumble ``JsonKeyStore`` namespace and
    holds at most one peer. Each :meth:`update` copies the current key map into
    the sibling namespace ``PREVIOUS_NAMESPACE_PREFIX + namespace`` before
    replacing it. Read and delete operations only touch the current generation.
    """

    def __init__(
        self,
        *,
        filename: str | Path,
        namespace: str | None = None,
        device: object | None = None,
    ) -> None:
        """Create a store backed by ``filename``.

        Args:
            filename: Path of the Bumble JSON key store file.
            namespace: Explicit namespace; used only when ``device`` is None.
            device: Bumble device; when given, the namespace comes from
                ``JsonKeyStore.from_device``.

        Raises:
            ValueError: If neither ``namespace`` nor ``device`` is given.
        """
        if namespace is None and device is None:
            msg = "namespace or device is required"
            raise ValueError(msg)
        self._filename = Path(filename)
        self._namespace = namespace
        self._device = device
        # True only when the most recent update() saved successfully and wrote
        # a non-empty previous generation.
        self.last_update_previous_saved = False

    @classmethod
    def from_device(
        cls,
        device: object,
        *,
        filename: str | Path,
    ) -> "_CurrentPreviousJsonKeyStore":
        """Create a store whose namespace follows the Bumble device address."""
        return cls(filename=filename, device=device)

    async def update(self, name: str, keys: object) -> None:
        """Write current keys and keep the overwritten current value as previous.

        Afterwards the current namespace holds only ``name``. If it previously
        had entries, they become the previous generation. If it was empty, any
        existing previous generation is removed.

        Args:
            name: Peer address the keys are stored under.
            keys: Bumble key object; must provide ``to_dict()``.
        """
        # Reset up front so a failure in load/to_dict/save is never reported as
        # a saved previous generation.
        self.last_update_previous_saved = False
        current_store = self._current_store()
        db, current_key_map = await current_store.load()
        previous_namespace = self._previous_namespace(current_store)
        # Deep copy: ``current_key_map`` is cleared and refilled below.
        previous_key_map = copy.deepcopy(current_key_map)
        if previous_key_map:
            db[previous_namespace] = previous_key_map
        else:
            db.pop(previous_namespace, None)
        current_key_map.clear()
        current_key_map[name] = cast("Any", keys).to_dict()
        await current_store.save(db)
        self.last_update_previous_saved = bool(previous_key_map)

    async def get(self, name: str) -> object | None:
        """Return the current key entry for ``name``, or ``None``."""
        for peer_address, peer_keys in await self.get_all():
            if peer_address == name:
                return peer_keys
        return None

    async def get_all(self) -> list[tuple[str, object]]:
        """Return current key entries only.

        Raises:
            InvalidKeyStoreError: If the current namespace holds more than one peer.
        """
        entries = await self._current_store().get_all()
        if len(entries) > 1:
            msg = "key store contains multiple current peers"
            raise InvalidKeyStoreError(msg)
        return entries

    async def delete(self, name: str) -> None:
        """Delete one current key entry."""
        await self._current_store().delete(name)

    async def delete_all(self) -> None:
        """Delete all current key entries (the previous generation is kept)."""
        await self._current_store().delete_all()

    async def get_resolving_keys(self) -> object:
        """Return current LE resolving keys for Bumble internals."""
        return await self._current_store().get_resolving_keys()

    def _current_store(self) -> _BumbleJsonKeyStoreRuntime:
        # A fresh Bumble JsonKeyStore is built per call; imported lazily so that
        # Bumble is only needed once the store is actually used.
        from bumble.keys import JsonKeyStore  # noqa: PLC0415

        if self._device is not None:
            return cast(
                "_BumbleJsonKeyStoreRuntime",
                JsonKeyStore.from_device(
                    cast("Any", self._device),
                    filename=str(self._filename),
                ),
            )
        return cast(
            "_BumbleJsonKeyStoreRuntime",
            JsonKeyStore(self._namespace, str(self._filename)),
        )

    def _previous_namespace(self, current_store: object) -> str:
        # Name of the namespace that holds the previous generation.
        return f"{PREVIOUS_NAMESPACE_PREFIX}{cast('Any', current_store).namespace}"
159
+
160
+
161
class _DiagnosticKeyStore:
    """Key store wrapper that records write outcome without logging key material.

    ``update`` emits a ``key_store_update`` diagnostics event with status
    ``succeeded`` or ``failed``. All other calls are forwarded to the wrapped
    store, and attributes not defined here are looked up on it.
    """

    def __init__(self, key_store: object, diagnostics: DiagnosticsRecorder) -> None:
        """Wrap ``key_store`` and report ``update`` outcomes to ``diagnostics``."""
        self._key_store = key_store
        self._diagnostics = diagnostics

    async def update(self, name: str, keys: object) -> None:
        """Record key-store write success or failure.

        The event carries the peer address plus generation fields. On failure
        it also carries the exception type and ``str(error)``, and the original
        exception is re-raised. ``keys`` itself is never passed to the recorder.
        """
        try:
            await cast("Any", self._key_store).update(name, keys)
        except Exception as error:
            fields = self._generation_fields()
            self._diagnostics.record_event(
                "key_store_update",
                **fields,
                error_type=type(error).__name__,
                message=str(error),
                peer_address=name,
                status="failed",
            )
            raise
        fields = self._generation_fields()
        self._diagnostics.record_event(
            "key_store_update",
            **fields,
            peer_address=name,
            status="succeeded",
        )

    async def get(self, name: str) -> object | None:
        """Delegate key lookup."""
        return await cast("Any", self._key_store).get(name)

    async def get_all(self) -> list[tuple[str, object]]:
        """Delegate full key listing."""
        return await cast("Any", self._key_store).get_all()

    async def delete(self, name: str) -> None:
        """Delegate key deletion."""
        await cast("Any", self._key_store).delete(name)

    async def delete_all(self) -> None:
        """Delegate all-key deletion."""
        await cast("Any", self._key_store).delete_all()

    async def get_resolving_keys(self) -> object:
        """Delegate LE resolving-key lookup for Bumble internals."""
        return await cast("Any", self._key_store).get_resolving_keys()

    def __getattr__(self, name: str) -> object:
        # Only reached for attributes missing on the wrapper. copy/deepcopy and
        # unpickling create an instance without running __init__ and then probe
        # e.g. ``__setstate__``. Without this guard, ``self._key_store`` would
        # re-enter __getattr__ forever (RecursionError) instead of failing cleanly.
        if name == "_key_store":
            raise AttributeError(name)
        return getattr(self._key_store, name)

    def _generation_fields(self) -> dict[str, object]:
        # Extra event fields, only available when wrapping the two-generation store.
        if not isinstance(self._key_store, _CurrentPreviousJsonKeyStore):
            return {}
        return {
            "generation": "current",
            "previous_saved": self._key_store.last_update_previous_saved,
        }
+ }