swbt-python 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
swbt/__init__.py ADDED
@@ -0,0 +1,37 @@
1
+ """Python API for virtual NX-compatible input devices."""
2
+
3
+ from swbt.diagnostics import DiagnosticsConfig, GamepadStatus
4
+ from swbt.errors import (
5
+ ClosedError,
6
+ ConnectionFailedError,
7
+ ConnectionTimeoutError,
8
+ InvalidInputError,
9
+ InvalidKeyStoreError,
10
+ SwbtError,
11
+ TransportOpenError,
12
+ )
13
+ from swbt.gamepad import ConnectionResult, SwitchGamepad, SwitchGamepadConfig
14
+ from swbt.input import Button, IMUFrame, InputState, Stick
15
+ from swbt.transport.base import BondedPeer, DisconnectRequestResult, HidDeviceTransport
16
+
17
# Names re-exported as the package's public API (alphabetical order).
__all__ = (
    "BondedPeer", "Button",
    "ClosedError", "ConnectionFailedError", "ConnectionResult", "ConnectionTimeoutError",
    "DiagnosticsConfig", "DisconnectRequestResult",
    "GamepadStatus",
    "HidDeviceTransport",
    "IMUFrame", "InputState", "InvalidInputError", "InvalidKeyStoreError",
    "Stick", "SwbtError", "SwitchGamepad", "SwitchGamepadConfig",
    "TransportOpenError",
)
swbt/diagnostics.py ADDED
@@ -0,0 +1,202 @@
1
+ """Minimal diagnostics state for gamepad lifecycle and callback errors."""
2
+
3
+ import json
4
+ import platform
5
+ from dataclasses import dataclass
6
+ from importlib.metadata import PackageNotFoundError, version
7
+ from typing import TextIO
8
+
9
+
10
+ @dataclass(frozen=True)
11
+ class DiagnosticsConfig:
12
+ """Diagnostics configuration accepted by SwitchGamepad.
13
+
14
+ Attributes:
15
+ trace_writer: Text stream that receives one JSON Lines diagnostics event per line.
16
+ """
17
+
18
+ trace_writer: TextIO | None = None
19
+
20
+
21
+ @dataclass(frozen=True)
22
+ class DiagnosticsEvent:
23
+ """One diagnostics event recorded by the gamepad.
24
+
25
+ Attributes:
26
+ event: Stable event name.
27
+ error_type: Exception type name for error events.
28
+ message: Human-readable error message for error events.
29
+ recoverable: Whether the error can be treated as recoverable.
30
+ fields: Event-specific structured fields.
31
+ """
32
+
33
+ event: str
34
+ error_type: str | None = None
35
+ message: str | None = None
36
+ recoverable: bool | None = None
37
+ fields: dict[str, object] | None = None
38
+
39
+
40
+ @dataclass(frozen=True)
41
+ class GamepadStatus:
42
+ """Snapshot of gamepad status exposed by SwitchGamepad.status().
43
+
44
+ Attributes:
45
+ connection_state: Current lifecycle state name.
46
+ report_counters: Sent report counts keyed by numeric report ID.
47
+ last_subcommand_id: Last observed subcommand ID, if any.
48
+ raw_rumble: Last raw rumble payload received from the host.
49
+ last_error: Latest diagnostics error event, if any.
50
+ """
51
+
52
+ connection_state: str
53
+ report_counters: dict[int, int]
54
+ last_subcommand_id: int | None
55
+ raw_rumble: bytes | None
56
+ last_error: DiagnosticsEvent | None
57
+
58
+
59
+ class DiagnosticsRecorder:
60
+ """Record a small in-memory diagnostics event history."""
61
+
62
+ def __init__(self, trace_writer: TextIO | None = None) -> None:
63
+ """Create an empty recorder."""
64
+ self._events: list[DiagnosticsEvent] = []
65
+ self._report_counters: dict[int, int] = {}
66
+ self._last_subcommand_id: int | None = None
67
+ self._raw_rumble: bytes | None = None
68
+ self._trace_writer = trace_writer
69
+
70
+ @property
71
+ def events(self) -> tuple[DiagnosticsEvent, ...]:
72
+ """Return recorded events in order."""
73
+ return tuple(self._events)
74
+
75
+ @property
76
+ def report_counters(self) -> dict[int, int]:
77
+ """Return sent report counters keyed by report ID."""
78
+ return dict(self._report_counters)
79
+
80
+ @property
81
+ def last_subcommand_id(self) -> int | None:
82
+ """Return the last observed subcommand ID."""
83
+ return self._last_subcommand_id
84
+
85
+ @property
86
+ def raw_rumble(self) -> bytes | None:
87
+ """Return the last observed raw rumble payload."""
88
+ return self._raw_rumble
89
+
90
+ @property
91
+ def last_error(self) -> DiagnosticsEvent | None:
92
+ """Return the latest error event."""
93
+ for event in reversed(self._events):
94
+ if event.event == "error":
95
+ return event
96
+ return None
97
+
98
+ def record_event(self, event: str, **fields: object) -> DiagnosticsEvent:
99
+ """Record a diagnostics event with schema fields."""
100
+ diagnostics_event = DiagnosticsEvent(event=event, fields=dict(fields))
101
+ self._append(diagnostics_event)
102
+ return diagnostics_event
103
+
104
+ def record_report_tx(self, *, report_id: int, reason: str) -> DiagnosticsEvent:
105
+ """Record one sent report and increment its counter."""
106
+ counter = self._report_counters.get(report_id, 0) + 1
107
+ self._report_counters[report_id] = counter
108
+ return self.record_event(
109
+ "report_tx",
110
+ counter=counter,
111
+ reason=reason,
112
+ report_id=f"0x{report_id:02x}",
113
+ )
114
+
115
+ def record_subcommand_rx(self, *, packet_id: int | None, subcommand_id: int) -> None:
116
+ """Record the latest observed subcommand ID."""
117
+ self._last_subcommand_id = subcommand_id
118
+ self.record_event(
119
+ "subcommand_rx",
120
+ packet_id=packet_id,
121
+ subcommand_id=f"0x{subcommand_id:02x}",
122
+ )
123
+
124
+ def record_raw_rumble(self, raw_rumble: bytes) -> None:
125
+ """Record the latest raw rumble bytes."""
126
+ self._raw_rumble = bytes(raw_rumble)
127
+
128
+ def record_run_metadata(
129
+ self,
130
+ *,
131
+ adapter: str,
132
+ key_store_exists: bool | None = None,
133
+ key_store_path: str | None = None,
134
+ key_store_previous_exists: bool | None = None,
135
+ ) -> DiagnosticsEvent:
136
+ """Record environment metadata for one diagnostics run."""
137
+ fields: dict[str, object] = {
138
+ "adapter": adapter,
139
+ "os": platform.system(),
140
+ "package_version": self._package_version(),
141
+ "python_version": platform.python_version(),
142
+ }
143
+ if key_store_path is not None:
144
+ fields["key_store_path"] = key_store_path
145
+ if key_store_exists is not None:
146
+ fields["key_store_exists"] = key_store_exists
147
+ if key_store_previous_exists is not None:
148
+ fields["key_store_previous_exists"] = key_store_previous_exists
149
+ return self.record_event(
150
+ "run_metadata",
151
+ **fields,
152
+ )
153
+
154
+ def record_state_transition(
155
+ self,
156
+ *,
157
+ previous: str,
158
+ next_state: str,
159
+ reason: str,
160
+ ) -> DiagnosticsEvent:
161
+ """Record one lifecycle state transition."""
162
+ return self.record_event(
163
+ "state_transition",
164
+ previous=previous,
165
+ next=next_state,
166
+ reason=reason,
167
+ )
168
+
169
+ def record_error(self, error: BaseException, *, recoverable: bool) -> DiagnosticsEvent:
170
+ """Record an exception as an error event."""
171
+ event = DiagnosticsEvent(
172
+ event="error",
173
+ error_type=type(error).__name__,
174
+ message=str(error),
175
+ recoverable=recoverable,
176
+ )
177
+ self._append(event)
178
+ return event
179
+
180
+ def _append(self, event: DiagnosticsEvent) -> None:
181
+ self._events.append(event)
182
+ if self._trace_writer is None:
183
+ return
184
+ payload: dict[str, object] = {"event": event.event}
185
+ if event.fields is not None:
186
+ payload.update(event.fields)
187
+ if event.error_type is not None:
188
+ payload["error_type"] = event.error_type
189
+ if event.message is not None:
190
+ payload["message"] = event.message
191
+ if event.recoverable is not None:
192
+ payload["recoverable"] = event.recoverable
193
+ self._trace_writer.write(json.dumps(payload, separators=(",", ":"), sort_keys=True))
194
+ self._trace_writer.write("\n")
195
+ self._trace_writer.flush()
196
+
197
+ @staticmethod
198
+ def _package_version() -> str:
199
+ try:
200
+ return version("swbt-python")
201
+ except PackageNotFoundError:
202
+ return "unknown"
swbt/errors.py ADDED
@@ -0,0 +1,33 @@
1
+ """Exception types exposed by swbt."""
2
+
3
+
4
class SwbtError(Exception):
    """Common ancestor of every exception raised by swbt."""


class TransportOpenError(SwbtError):
    """The transport could not be opened."""


class ConnectionTimeoutError(SwbtError):
    """Waiting for a connection exceeded its timeout."""


class ConnectionFailedError(SwbtError):
    """A connection attempt ended without producing a connection."""


class InvalidKeyStoreError(SwbtError):
    """The key store is malformed or has an unsupported shape."""


class ProtocolError(SwbtError):
    """Protocol bytes could not be parsed or produced."""


class ClosedError(SwbtError):
    """The operation needs an open controller, but it is not open."""


class InvalidInputError(SwbtError):
    """A user-supplied input value lies outside the supported range."""
swbt/gamepad/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ """Public gamepad facade."""
2
+
3
+ from swbt.gamepad.connection import ConnectionResult, ConnectionStatus
4
+ from swbt.gamepad.core import SwitchGamepad, SwitchGamepadConfig
5
+
6
# Disconnect-request timeout (seconds, per its name). Not used in this module;
# it is re-exported below for callers.
DISCONNECT_REQUEST_TIMEOUT_SECONDS = 0.25

# Public names of the gamepad facade: the constant first, then the classes
# and type alias in alphabetical order.
__all__ = (
    "DISCONNECT_REQUEST_TIMEOUT_SECONDS",
    "ConnectionResult", "ConnectionStatus",
    "SwitchGamepad", "SwitchGamepadConfig",
)
swbt/gamepad/connection.py ADDED
@@ -0,0 +1,222 @@
1
+ """Connection workflow for SwitchGamepad."""
2
+
3
+ import asyncio
4
+ from collections.abc import Awaitable, Callable
5
+ from dataclasses import dataclass
6
+ from typing import Literal
7
+
8
+ from swbt.diagnostics import DiagnosticsRecorder
9
+ from swbt.errors import (
10
+ ClosedError,
11
+ ConnectionFailedError,
12
+ ConnectionTimeoutError,
13
+ InvalidKeyStoreError,
14
+ )
15
+ from swbt.transport.base import HidDeviceTransport
16
+
17
# Which connection path produced a ConnectionResult.
ConnectionRoute = Literal["active_reconnect", "pairing"]
# How a connection attempt ended.
ConnectionStatus = Literal["connected", "no_bond", "timeout", "failed"]

# Callback signatures injected into ConnectionWorkflow.
# Plain (synchronous) hooks:
TransportProvider = Callable[[], HidDeviceTransport | None]
StateSetter = Callable[[str], None]
EventClearer = Callable[[], None]
# Awaitable hooks; a float | None argument receives the caller's timeout.
EnsureOpen = Callable[[], Awaitable[None]]
CloseNeutral = Callable[[], Awaitable[None]]
WaitForConnected = Callable[[float | None], Awaitable[None]]
PairWithTimeout = Callable[[float | None], Awaitable[None]]
27
+
28
+
29
@dataclass(frozen=True)
class ConnectionResult:
    """Immutable outcome of an explicit connection strategy.

    Attributes:
        route: The path that produced this result ("active_reconnect" or
            "pairing").
        status: How the attempt ended ("connected", "no_bond", "timeout" or
            "failed").
        peer_address: Address of the bonded peer chosen for reconnect; None
            when no peer was selected.
        peer_count: How many bonded peers were seen while choosing a
            reconnect target; None when not reported.
    """

    # Always present.
    route: ConnectionRoute
    status: ConnectionStatus
    # Optional details; pairing results built in this module leave them None.
    peer_address: str | None = None
    peer_count: int | None = None
44
+
45
+
46
@dataclass
class ConnectionWorkflow:
    """Run active reconnect and pairing fallback workflows.

    All side effects go through the injected callables and the diagnostics
    recorder.

    Attributes:
        clear_connected: Called right before a reconnect attempt to reset the
            previous connected signal.
        close_neutral: Awaited after a reconnect attempt times out or fails.
        diagnostics: Recorder that receives every workflow event.
        ensure_open: Awaited before the transport is looked up.
        get_transport: Returns the current transport, or None when closed.
        key_store_path: Key store location. None (without an injected
            transport) is reported as a diagnostics event but does not stop
            the reconnect attempt.
        pair: Awaited with the timeout for the pairing fallback.
        set_connection_state: Receives lifecycle state names (e.g.
            "reconnecting").
        transport_was_injected: Whether the transport was supplied by the
            caller.
        wait_for_connected: Awaited with the timeout after the transport's
            connect call.
    """

    clear_connected: EventClearer
    close_neutral: CloseNeutral
    diagnostics: DiagnosticsRecorder
    ensure_open: EnsureOpen
    get_transport: TransportProvider
    key_store_path: str | None
    pair: PairWithTimeout
    set_connection_state: StateSetter
    transport_was_injected: bool
    wait_for_connected: WaitForConnected

    async def try_reconnect(
        self,
        timeout: float | None = None,  # noqa: ASYNC109
    ) -> ConnectionResult:
        """Try active reconnect with exactly one bonded peer.

        Args:
            timeout: Forwarded unchanged both as the transport's
                ``connect_timeout`` and to ``wait_for_connected``.

        Returns:
            A result on the ``active_reconnect`` route with status
            ``no_bond`` (no bonded peer), ``connected``, ``timeout`` (the
            connect call or the wait raised ``TimeoutError`` or
            ``ConnectionTimeoutError``) or ``failed`` (any other error from
            those two steps).

        Raises:
            ClosedError: If no transport is available.
            InvalidKeyStoreError: If more than one bonded peer is present.
            asyncio.CancelledError: If the calling task itself is cancelled.
        """
        await self.ensure_open()
        transport = self._transport()
        if self.key_store_path is None and not self.transport_was_injected:
            # Informational only: the reconnect attempt still proceeds.
            self.diagnostics.record_event(
                "reconnect_key_store_unavailable",
                reason="key_store_path_none",
                route="active_reconnect",
            )
        peers = await transport.list_bonded_peers()
        peer_count = len(peers)
        if peer_count > 1:
            self.diagnostics.record_event(
                "invalid_key_store",
                peer_count=peer_count,
                reason="multiple_current_peers",
            )
            msg = "key store contains multiple current peers"
            raise InvalidKeyStoreError(msg)
        self.diagnostics.record_event(
            "bonded_peers_discovered",
            peer_count=peer_count,
            selection=_bonded_peer_selection(peer_count),
        )
        if not peers:
            self.diagnostics.record_event(
                "active_reconnect_result",
                peer_count=0,
                route="active_reconnect",
                status="no_bond",
            )
            return ConnectionResult(
                route="active_reconnect",
                status="no_bond",
                peer_count=0,
            )
        return await self._reconnect_peer(transport, peers[0].address, timeout)

    async def try_connect(
        self,
        *,
        timeout: float | None = None,  # noqa: ASYNC109
        allow_pairing: bool = False,
    ) -> ConnectionResult:
        """Try bonded reconnect first, then optional pairing fallback.

        Pairing is attempted only when reconnect reports ``no_bond`` and
        ``allow_pairing`` is true. A ``ConnectionTimeoutError`` from pairing
        becomes a ``timeout`` result; other pairing errors propagate.
        """
        reconnect_result = await self.try_reconnect(timeout=timeout)
        if reconnect_result.status != "no_bond" or not allow_pairing:
            return reconnect_result
        self.diagnostics.record_event(
            "connect_pairing_fallback",
            reason="no_bond",
            route="pairing",
        )
        try:
            await self.pair(timeout)
        except ConnectionTimeoutError:
            return ConnectionResult(route="pairing", status="timeout")
        return ConnectionResult(route="pairing", status="connected")

    async def _reconnect_peer(
        self,
        transport: HidDeviceTransport,
        peer_address: str,
        timeout: float | None,  # noqa: ASYNC109
    ) -> ConnectionResult:
        """Connect to the single bonded peer and classify the outcome."""
        self.set_connection_state("reconnecting")
        self.clear_connected()
        self.diagnostics.record_event(
            "active_reconnect_attempt",
            peer_address=peer_address,
            route="active_reconnect",
        )
        try:
            await transport.connect_bonded_peer(
                peer_address,
                connect_timeout=timeout,
            )
            await self.wait_for_connected(timeout)
        except (TimeoutError, ConnectionTimeoutError):
            # ConnectionTimeoutError is the package's own timeout type (try_connect
            # treats it as a timeout for pairing), so report it as a timeout here
            # too rather than as a generic transport failure.
            return await self._record_timeout(peer_address=peer_address)
        except asyncio.CancelledError as error:
            # Propagate cancellation aimed at this task; a CancelledError that
            # merely leaked out of the transport is reported as a failed attempt.
            if _current_task_is_cancelling():
                raise
            return await self._record_transport_error(error, peer_address=peer_address)
        except Exception as error:  # noqa: BLE001
            return await self._record_transport_error(error, peer_address=peer_address)

        self.diagnostics.record_event(
            "active_reconnect_result",
            peer_address=peer_address,
            route="active_reconnect",
            status="connected",
        )
        return ConnectionResult(
            route="active_reconnect",
            status="connected",
            peer_address=peer_address,
            peer_count=1,
        )

    def _transport(self) -> HidDeviceTransport:
        """Return the current transport or raise ClosedError when there is none."""
        transport = self.get_transport()
        if transport is None:
            msg = "gamepad is not open"
            raise ClosedError(msg)
        return transport

    async def _record_timeout(self, *, peer_address: str) -> ConnectionResult:
        """Record a timed-out reconnect, await close_neutral, and return the result."""
        self.diagnostics.record_event(
            "active_reconnect_result",
            failure_reason="connection_timeout",
            peer_address=peer_address,
            route="active_reconnect",
            status="timeout",
        )
        await self.close_neutral()
        return ConnectionResult(
            route="active_reconnect",
            status="timeout",
            peer_address=peer_address,
            peer_count=1,
        )

    async def _record_transport_error(
        self,
        error: BaseException,
        *,
        peer_address: str,
    ) -> ConnectionResult:
        """Record a failed reconnect (event + recoverable error) and return the result."""
        self.diagnostics.record_event(
            "active_reconnect_result",
            error_type=type(error).__name__,
            failure_reason="transport_error",
            message=str(error),
            peer_address=peer_address,
            route="active_reconnect",
            status="failed",
        )
        self.diagnostics.record_error(error, recoverable=True)
        await self.close_neutral()
        return ConnectionResult(
            route="active_reconnect",
            status="failed",
            peer_address=peer_address,
            peer_count=1,
        )
201
+
202
+
203
+ def raise_if_connection_failed(result: ConnectionResult) -> None:
204
+ """Raise the public connection error for a non-connected result."""
205
+ if result.status == "connected":
206
+ return
207
+ if result.status == "timeout":
208
+ msg = "connection timed out"
209
+ raise ConnectionTimeoutError(msg)
210
+ msg = f"connection failed: {result.status}"
211
+ raise ConnectionFailedError(msg)
212
+
213
+
214
def _bonded_peer_selection(peer_count: int) -> str:
    """Label the reconnect peer choice for diagnostics: "none" for zero peers, else "selected"."""
    return "selected" if peer_count else "none"
218
+
219
+
220
+ def _current_task_is_cancelling() -> bool:
221
+ task = asyncio.current_task()
222
+ return task is not None and task.cancelling() > 0