lr-shuttle 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lr-shuttle might be problematic. Click here for more details.

shuttle/constants.py ADDED
@@ -0,0 +1,41 @@
1
+ #! /usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """Shared constants for Shuttle CLI and library consumers."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import Dict, Set
8
+
9
+
10
# Keys accepted in an SPI configuration mapping.
SPI_ALLOWED_FIELDS: Set[str] = {
    "hz",
    "setup_us",
    "cs_active",
    "bit_order",
    "byte_order",
    "clock_phase",
    "clock_polarity",
}

# For the enumerated SPI fields, the set of values each one may take.
SPI_CHOICE_FIELDS: Dict[str, Set[str]] = {
    "cs_active": {"high", "low"},
    "bit_order": {"lsb", "msb"},
    "byte_order": {"little", "big"},
    "clock_polarity": {"idle_high", "idle_low"},
    "clock_phase": {"trailing", "leading"},
}

# Maps every accepted parity spelling (short or long) to its one-letter form.
UART_PARITY_ALIASES: Dict[str, str] = {
    alias: short
    for short, long_name in (("n", "none"), ("e", "even"), ("o", "odd"))
    for alias in (short, long_name)
}

# Defaults used when opening the serial link.
DEFAULT_BAUD = 921600
DEFAULT_TIMEOUT = 2.0
shuttle/prodtest.py ADDED
@@ -0,0 +1,120 @@
1
+ #! /usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """Helpers for the prodtest SPI command protocol."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import Iterable, Sequence
8
+
9
+ from . import timo
10
+
11
# Single-byte prodtest opcodes, expressed as the ASCII character they encode.
RESET_OPCODE = ord("?")
PLUS_OPCODE = ord("+")
IO_SELF_TEST_OPCODE = ord("T")

# GPIO self-test framing: mask width in bytes, the filler byte clocked out
# during readback, and how long to wait for the IRQ (microseconds).
IO_SELF_TEST_MASK_LEN = 8
IO_SELF_TEST_DUMMY_BYTE = 0xFF
IO_SELF_TEST_IRQ_TIMEOUT_US = 1_000_000
17
+
18
+
19
+ def _ensure_byte(value: int) -> int:
20
+ if not 0 <= value <= 0xFF:
21
+ raise ValueError("Prodtest arguments must be in range 0..255")
22
+ return value
23
+
24
+
25
+ def _build_command_bytes(opcode: int, arguments: Iterable[int] | bytes = ()) -> bytes:
26
+ """Build raw bytes from an opcode and a sequence of byte arguments."""
27
+
28
+ if isinstance(arguments, bytes):
29
+ payload = arguments
30
+ else:
31
+ payload = bytes(_ensure_byte(arg) for arg in arguments)
32
+ return bytes([opcode]) + payload
33
+
34
+
35
def command(opcode: int, arguments: Iterable[int] | bytes = ()) -> dict:
    """Wrap a prodtest opcode and its arguments via ``timo.command_payload``."""
    raw_frame = _build_command_bytes(opcode, arguments)
    return timo.command_payload(raw_frame)
39
+
40
+
41
def reset() -> dict:
    """Build the prodtest reset command, which is the lone '?' opcode byte."""
    payload = command(RESET_OPCODE)
    return payload
45
+
46
+
47
def reset_transfer() -> dict:
    """Alias for :func:`reset`; returns the same reset payload."""
    payload = reset()
    return payload
51
+
52
+
53
def ping_sequence() -> Sequence[dict]:
    """Build the two SPI frames that make up the prodtest ping.

    Frame one sends '+' (its response is to be ignored); frame two clocks out
    a 0xFF dummy byte, during which the device is expected to answer '-' (0x2D).
    """
    frames = (bytes([PLUS_OPCODE]), bytes([0xFF]))
    return [timo.command_payload(frame) for frame in frames]
61
+
62
+
63
def io_self_test(mask: bytes) -> Sequence[dict]:
    """Build the two SPI frames that run the GPIO self-test.

    The first frame carries the 'T' opcode plus the pin mask and asks the
    bridge to wait for a leading IRQ edge; the second clocks out dummy bytes
    to read the result back.

    Raises:
        ValueError: If *mask* is not exactly ``IO_SELF_TEST_MASK_LEN`` bytes.
    """
    if len(mask) != IO_SELF_TEST_MASK_LEN:
        raise ValueError("IO self-test mask must be exactly 8 bytes")

    # Named ``request`` so the module-level ``command`` function is not shadowed.
    request = _build_command_bytes(IO_SELF_TEST_OPCODE, mask)
    readback = bytes([IO_SELF_TEST_DUMMY_BYTE] * IO_SELF_TEST_MASK_LEN)

    irq_wait = {
        "edge": "leading",
        "timeout_us": IO_SELF_TEST_IRQ_TIMEOUT_US,
    }
    trigger_frame = timo.command_payload(request, params={"wait_irq": irq_wait})
    readback_frame = timo.command_payload(readback)
    return (trigger_frame, readback_frame)
82
+
83
+
84
def mask_from_hex(value: str) -> bytes:
    """Parse a hex-encoded mask and ensure it is 8 bytes long.

    Surrounding whitespace is ignored and hex digits may be upper or lower
    case.

    Raises:
        ValueError: If the text is not exactly 16 hex digits.
    """
    trimmed = value.strip().lower()
    if len(trimmed) != IO_SELF_TEST_MASK_LEN * 2:
        raise ValueError("Mask must be 16 hex characters (8 bytes)")
    try:
        decoded = bytes.fromhex(trimmed)
    except ValueError as exc:
        raise ValueError("Mask must be a valid hex string") from exc
    # bytes.fromhex() skips ASCII whitespace, so 16 characters that include
    # embedded spaces decode to fewer than 8 bytes; reject those here.
    if len(decoded) != IO_SELF_TEST_MASK_LEN:
        raise ValueError("Mask must be 16 hex characters (8 bytes)")
    return decoded
95
+
96
+
97
def mask_to_hex(mask: bytes) -> str:
    """Format *mask* as hex text using uppercase digits."""
    lowercase_hex = mask.hex()
    return lowercase_hex.upper()
101
+
102
+
103
def pins_from_mask(mask: bytes) -> list[int]:
    """List the 1-indexed pins whose bits are set in *mask*.

    The last byte of the mask holds pins 1-8 (bit 0 is pin 1), the byte
    before it pins 9-16, and so on; the result is in ascending order.
    """
    return [
        offset * 8 + bit + 1
        for offset, octet in enumerate(reversed(mask))
        for bit in range(8)
        if octet & (1 << bit)
    ]
112
+
113
+
114
def failed_pins(request_mask: bytes, result_mask: bytes) -> list[int]:
    """List, in ascending order, the pins set in *request_mask* but not in *result_mask*."""
    wanted = set(pins_from_mask(request_mask))
    succeeded = set(pins_from_mask(result_mask))
    return sorted(wanted.difference(succeeded))
@@ -0,0 +1,478 @@
1
+ #! /usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """Serial transport helpers for the Shuttle bridge."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ import secrets
9
+ import threading
10
+ from datetime import datetime, timezone
11
+ from pathlib import Path
12
+ from typing import Any, Callable, Dict, List, Optional
13
+
14
+ from collections import deque
15
+ from concurrent.futures import Future
16
+ import serial
17
+ from serial import SerialException
18
+ from .constants import DEFAULT_BAUD, DEFAULT_TIMEOUT
19
+
20
+
21
class ShuttleSerialError(Exception):
    """Error type for unrecoverable failures in the serial transport layer."""
23
+
24
+
25
class CommandFuture(Future):
    """Future representing a pending device command.

    When a positive *timeout* is given, a daemon timer fails the future with
    ``ShuttleSerialError`` if nothing resolves it first. Resolution may be
    attempted from several threads (the timer and whoever delivers the
    response); the first attempt wins and later ones are ignored.
    """

    def __init__(self, *, cmd_id: int, timeout: Optional[float]):
        super().__init__()
        self.cmd_id = cmd_id
        # Serializes the "check done, then set" step so two threads cannot
        # both pass the check and trip InvalidStateError. Re-entrant because
        # done-callbacks run inside set_result()/set_exception().
        self._resolve_lock = threading.RLock()
        self._timer: Optional[threading.Timer] = None
        if timeout and timeout > 0:
            timer = threading.Timer(timeout, self._set_timeout)
            timer.daemon = True
            self._timer = timer
            timer.start()

    def mark_result(self, result: Dict[str, Any]) -> None:
        """Resolve the future with *result* unless it is already done."""
        self._cancel_timer()
        with self._resolve_lock:
            if not self.done():
                self.set_result(result)

    def mark_exception(self, exc: BaseException) -> None:
        """Fail the future with *exc* unless it is already done."""
        self._cancel_timer()
        with self._resolve_lock:
            if not self.done():
                self.set_exception(exc)

    def _set_timeout(self) -> None:
        """Timer callback: fail the future because no response arrived."""
        self.mark_exception(ShuttleSerialError("Timed out waiting for device response"))

    def _cancel_timer(self) -> None:
        """Stop the timeout timer, if one is still armed."""
        # Take ownership of the timer under the lock so a concurrent caller
        # cannot observe it as set and then find it cleared.
        with self._resolve_lock:
            timer, self._timer = self._timer, None
        if timer is not None:
            timer.cancel()
54
+
55
+
56
class EventSubscription:
    """Multiple-use future that resolves every time an event arrives.

    Payloads are delivered in arrival order. Once the subscription has been
    failed or closed, payloads that were already received are still handed
    out; after they are exhausted every further read raises the failure.
    """

    def __init__(
        self,
        event: str,
        teardown: Optional[Callable[["EventSubscription"], None]] = None,
    ):
        self.event = event
        self._teardown = teardown
        self._lock = threading.Lock()
        self._queue = deque()
        self._current: Future = Future()
        self._closed = False
        # Exception passed to fail(); raised once buffered payloads run out.
        self._failure: Optional[BaseException] = None
        # The future that carries the failure; it is never replaced.
        self._terminal: Optional[Future] = None

    def next(self, timeout: Optional[float] = None) -> Dict[str, Any]:
        """Block until the next event payload is available.

        Raises whatever ``Future.result`` raises on timeout, or the failure
        exception once the subscription is closed and drained.
        """

        with self._lock:
            if self._current.done():
                result = self._current.result()
                self._advance_locked()
                return result
            if self._queue:
                payload = self._queue.popleft()
                self._advance_locked()
                return payload
            current = self._current
        try:
            return current.result(timeout=timeout)
        finally:
            with self._lock:
                # Only consume the future this call waited on; if another
                # reader already advanced, the new head is not ours to drop.
                if self._current is current:
                    self._advance_locked()

    def future(self) -> Future:
        """Return a future for the next payload.

        NOTE(review): when the head future is already resolved it is returned
        without being consumed, so repeated calls return the same payload
        until ``next()`` is called. Confirm this is intended.
        """
        with self._lock:
            if self._current.done():
                return self._current
            if self._queue:
                completed: Future = Future()
                completed.set_result(self._queue.popleft())
                self._advance_locked()
                return completed
            return self._current

    def emit(self, payload: Dict[str, Any]) -> None:
        """Deliver *payload* to the subscriber; ignored after close."""
        with self._lock:
            if self._closed:
                return
            if not self._current.done():
                self._current.set_result(payload)
            else:
                self._queue.append(payload)

    def fail(self, exc: BaseException) -> None:
        """Close the subscription so that readers eventually raise *exc*."""
        with self._lock:
            if self._closed:
                return
            self._closed = True
            self._failure = exc
            current = self._current
            if not current.done():
                current.set_exception(exc)
                self._terminal = current
            # Otherwise an unread payload is at the head; the failure is
            # installed by _advance_locked() once the backlog is drained.
        self._teardown_if_needed()

    def close(self) -> None:
        """Close the subscription with a generic "closed" error."""
        self.fail(ShuttleSerialError("Event subscription closed"))

    def _advance(self) -> None:
        """Locking wrapper around :meth:`_advance_locked`."""
        with self._lock:
            self._advance_locked()

    def _advance_locked(self) -> None:
        """Replace a consumed head future with the next one (lock held)."""
        current = self._current
        if not current.done() or current is self._terminal:
            return
        upcoming: Future = Future()
        if self._queue:
            upcoming.set_result(self._queue.popleft())
        elif self._closed:
            # Backlog exhausted after close: from here on, surface the failure.
            upcoming.set_exception(self._failure)
            self._terminal = upcoming
        self._current = upcoming

    def _teardown_if_needed(self) -> None:
        """Invoke the teardown callback, if one was supplied."""
        if self._teardown is not None:
            self._teardown(self)
141
+
142
+
143
class SerialLogger:
    """Append a timestamped, human-readable record of serial traffic to a file."""

    def __init__(self, path: Path):
        log_path = Path(path)
        self._path = log_path
        self._file = log_path.open("a", encoding="utf-8")

    def log(self, direction: str, data: bytes) -> None:
        """Write one line: UTC ISO timestamp, *direction*, then the decoded data.

        Undecodable bytes are replaced; leading whitespace and trailing line
        terminators are removed. The file is flushed after every line.
        """
        stamp = datetime.now(timezone.utc).isoformat()
        decoded = data.decode("utf-8", errors="replace")
        body = decoded.lstrip().rstrip("\r\n")
        self._file.write(f"{stamp} {direction} {body}\n")
        self._file.flush()

    def close(self) -> None:
        """Close the log file; calling this again is a no-op."""
        if self._file.closed:
            return
        self._file.close()
159
+
160
+
161
class SequenceTracker:
    """Verify monotonic `seq` values and optionally persist the last seen value.

    When *path* is given, the last observed sequence number is loaded from
    that file on construction and rewritten after every observation.
    """

    def __init__(self, path: Optional[Path] = None):
        # Coerce so that plain string paths work too, as SerialLogger does.
        self._path: Optional[Path] = Path(path) if path is not None else None
        self._last_seq: Optional[int] = None
        if self._path is not None:
            self._initialize_from_file()

    def _initialize_from_file(self) -> None:
        """Load the persisted sequence number, creating the directory if needed.

        Raises:
            ValueError: If the directory cannot be created, the file cannot
                be read, or its content is not an integer.
        """
        path = self._path
        if path is None:
            return
        try:
            path.parent.mkdir(parents=True, exist_ok=True)
        except OSError as exc:
            raise ValueError(
                f"Unable to create sequence meta directory: {exc}"
            ) from exc
        try:
            # Read directly rather than checking exists() first, so a file
            # removed in between is treated as "no prior state".
            contents = path.read_text(encoding="utf-8").strip()
        except FileNotFoundError:
            return
        except (OSError, UnicodeDecodeError) as exc:
            raise ValueError(f"Unable to read sequence meta file: {exc}") from exc
        if not contents:
            return
        try:
            self._last_seq = int(contents)
        except ValueError as exc:
            raise ValueError("Sequence meta file must contain an integer") from exc

    def observe(self, seq: int, *, source: str) -> None:
        """Record *seq* and raise if it does not follow the previous value.

        The new value is stored and persisted even when a gap is detected,
        so subsequent runs expect the device's current counter.

        Raises:
            ShuttleSerialError: On a gap, or if the value cannot be persisted.
        """
        previous = self._last_seq
        self._last_seq = seq
        self._persist()
        if previous is None:
            return
        expected = previous + 1
        if seq != expected:
            raise ShuttleSerialError(
                f"Detected gap in device sequence numbers (expected {expected}, got {seq}) while processing {source}"
            )

    def _persist(self) -> None:
        """Write the last seen sequence number to the meta file, if configured."""
        if self._path is None:
            return
        try:
            self._path.write_text(str(self._last_seq), encoding="utf-8")
        except OSError as exc:
            raise ShuttleSerialError(
                f"Unable to write sequence meta file: {exc}"
            ) from exc
216
+
217
+
218
class NDJSONSerialClient:
    """Minimal NDJSON transport over a serial link.

    Commands are written as one JSON object per line. A background reader
    thread, started on demand, matches ``resp`` messages to pending command
    futures by ``id`` and fans ``ev`` messages out to event subscriptions.
    """

    def __init__(
        self,
        port: str,
        *,
        baudrate: int = DEFAULT_BAUD,
        timeout: float = DEFAULT_TIMEOUT,
        logger: Optional[SerialLogger] = None,
        seq_tracker: Optional[SequenceTracker] = None,
    ):
        try:
            self._serial = serial.Serial(port=port, baudrate=baudrate, timeout=timeout)
        except SerialException as exc:  # pragma: no cover - hardware specific
            raise ShuttleSerialError(f"Unable to open {port}: {exc}") from exc
        try:
            self._serial.reset_input_buffer()
        except SerialException as exc:  # pragma: no cover - hardware specific
            # Do not leak the open port when initialisation fails half-way.
            self._serial.close()
            raise ShuttleSerialError(f"Unable to open {port}: {exc}") from exc
        self._lock = threading.Lock()
        self._pending: Dict[int, CommandFuture] = {}
        self._response_backlog: Dict[int, Dict[str, Any]] = {}
        self._event_listeners: Dict[str, List[EventSubscription]] = {}
        self._stop_event = threading.Event()
        self._response_timeout = timeout
        self._logger = logger
        self._seq_tracker = seq_tracker
        self._reader: Optional[threading.Thread] = None

    def __enter__(self) -> "NDJSONSerialClient":
        return self

    def __exit__(self, exc_type, exc, exc_tb) -> None:
        self.close()

    def close(self) -> None:
        """Stop the reader, fail everything outstanding and close the port."""
        # Attribute guards: __init__ may have raised before setting them.
        stop_event = getattr(self, "_stop_event", None)
        if stop_event is not None:
            stop_event.set()
        reader = getattr(self, "_reader", None)
        if (
            reader is not None
            and reader is not threading.current_thread()  # cannot join oneself
            and reader.is_alive()
        ):
            reader.join(timeout=getattr(self, "_response_timeout", 0.1) or 0.1)
        if hasattr(self, "_pending"):
            self._fail_all(ShuttleSerialError("Serial client closed"))
        port = getattr(self, "_serial", None)
        if port is not None and port.is_open:
            port.close()

    def send_command(self, op: str, params: Dict[str, Any]) -> CommandFuture:
        """Send a command without blocking, returning a future for the response.

        Raises:
            ShuttleSerialError: If the command could not be written.
        """

        cmd_id = self._next_cmd_id()
        message: Dict[str, Any] = {"type": "cmd", "id": cmd_id, "op": op}
        message.update(params)
        future = CommandFuture(cmd_id=cmd_id, timeout=self._response_timeout)
        with self._lock:
            self._pending[cmd_id] = future

        def _cleanup(_future: Future) -> None:
            if hasattr(self, "_lock"):
                self._remove_pending(cmd_id)
            elif hasattr(self, "_pending"):
                self._pending.pop(cmd_id, None)

        future.add_done_callback(_cleanup)
        try:
            self._write(message)
        except Exception as exc:
            # Nothing was sent, so resolve the future now (which also drops
            # the pending entry and timer) instead of letting it time out.
            future.mark_exception(exc)
            raise

        # Start reader after pending entry exists so early responses can be matched
        self._ensure_reader_started()

        # If a response already arrived, deliver immediately
        with self._lock:
            backlog = self._response_backlog.pop(cmd_id, None)
        if backlog is not None:
            future.mark_result(backlog)
        return future

    def register_event_listener(self, event: str) -> EventSubscription:
        """Subscribe to a device event; each emission resolves the subscription future."""

        def teardown(listener: EventSubscription) -> None:
            with self._lock:
                listeners = self._event_listeners.get(event, [])
                if listener in listeners:
                    listeners.remove(listener)
                if not listeners and event in self._event_listeners:
                    self._event_listeners.pop(event, None)

        listener = EventSubscription(event, teardown=teardown)
        with self._lock:
            self._event_listeners.setdefault(event, []).append(listener)
        self._ensure_reader_started()
        return listener

    def spi_xfer(
        self, *, tx: str, n: Optional[int] = None, **overrides: Any
    ) -> Dict[str, Any]:
        """Run ``spi.xfer``; *n* defaults to half the length of *tx*."""
        payload: Dict[str, Any] = {"tx": tx}
        payload.update(overrides)
        payload["n"] = n if n is not None else len(tx) // 2
        return self._command("spi.xfer", payload)

    def spi_enable(self) -> Dict[str, Any]:
        return self._command("spi.enable", {})

    def spi_disable(self) -> Dict[str, Any]:
        return self._command("spi.disable", {})

    def get_info(self) -> Dict[str, Any]:
        return self._command("get.info", {})

    def ping(self) -> Dict[str, Any]:
        return self._command("ping", {})

    def spi_cfg(self, spi: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Run ``spi.cfg``; a non-empty *spi* mapping is sent under ``"spi"``."""
        payload: Dict[str, Any] = {}
        if spi:
            payload["spi"] = spi
        return self._command("spi.cfg", payload)

    def uart_cfg(self, uart: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Run ``uart.cfg``; a non-empty *uart* mapping is sent under ``"uart"``."""
        payload: Dict[str, Any] = {}
        if uart:
            payload["uart"] = uart
        return self._command("uart.cfg", payload)

    def uart_tx(self, data: str, port: Optional[int] = None) -> Dict[str, Any]:
        """Run ``uart.tx`` with *data*, optionally addressing a specific *port*."""
        payload: Dict[str, Any] = {"data": data}
        if port is not None:
            payload["port"] = port
        return self._command("uart.tx", payload)

    def _command(self, op: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Send a command and block until its future resolves."""
        future = self.send_command(op, params)
        return future.result()

    def _next_cmd_id(self) -> int:
        """Pick the id for the next command.

        NOTE(review): when an unmatched response is sitting in the backlog its
        id is reused, so the next command is resolved with that response even
        though it answers some earlier request. Confirm this is intended
        outside of tests; otherwise stale responses should be discarded.
        """
        with self._lock:
            if self._response_backlog:
                return next(iter(self._response_backlog))
        while True:
            candidate = secrets.randbits(16)
            with self._lock:
                if (
                    candidate != 0
                    and candidate not in self._pending
                    and candidate not in self._response_backlog
                ):
                    return candidate

    def _remove_pending(self, cmd_id: int) -> None:
        with self._lock:
            self._pending.pop(cmd_id, None)

    def _reader_loop(self) -> None:
        """Read and dispatch messages until stopped, idle, or an error occurs."""
        me = threading.current_thread()
        try:
            while not self._stop_event.is_set():
                with self._lock:
                    if not self._pending and not self._event_listeners:
                        # Deregister in the same critical section as the idle
                        # check so a concurrent send_command() starts a fresh
                        # reader instead of relying on this exiting one.
                        if self._reader is me:
                            self._reader = None
                        return
                try:
                    message = self._read()
                except ShuttleSerialError as exc:
                    self._detach_reader(me)
                    self._fail_all(exc)
                    return
                if message is None:
                    continue
                try:
                    self._dispatch(message)
                except ShuttleSerialError as exc:
                    self._detach_reader(me)
                    self._fail_all(exc)
                    return
        finally:
            self._detach_reader(me)

    def _detach_reader(self, thread: threading.Thread) -> None:
        """Forget *thread* as the active reader if it still holds that role."""
        with self._lock:
            if self._reader is thread:
                self._reader = None

    def _ensure_reader_started(self) -> None:
        """Start the reader thread unless one is already registered and alive."""
        with self._lock:
            reader = self._reader
            if reader is not None and reader.is_alive():
                return
            reader = threading.Thread(target=self._reader_loop, daemon=True)
            self._reader = reader
            reader.start()

    def _write(self, message: Dict[str, Any]) -> None:
        """Serialize *message* as compact JSON plus newline and write it.

        Raises:
            ShuttleSerialError: If the serial write fails.
        """
        serialized = json.dumps(message, separators=(",", ":"))
        payload = serialized.encode("utf-8") + b"\n"
        with self._lock:
            try:
                self._serial.write(payload)
            except SerialException as exc:  # pragma: no cover - hardware specific
                raise ShuttleSerialError(f"Serial write failed: {exc}") from exc
            self._log_serial("TX", payload)

    def _read(self) -> Optional[Dict[str, Any]]:
        """Read one line and decode it; returns ``None`` on timeout or blank line.

        Raises:
            ShuttleSerialError: On read failure, invalid UTF-8, invalid JSON,
                a JSON value that is not an object, or a sequence gap.
        """
        try:
            line = self._serial.readline()
        except SerialException as exc:  # pragma: no cover - hardware specific
            raise ShuttleSerialError(f"Serial read failed: {exc}") from exc
        if not line:
            return None
        self._log_serial("RX", line)
        stripped = line.strip()
        if not stripped:
            return None
        try:
            decoded = stripped.decode("utf-8")
        except UnicodeDecodeError as exc:
            raise ShuttleSerialError(f"Invalid UTF-8 from device: {exc}") from exc
        try:
            message = json.loads(decoded)
        except json.JSONDecodeError as exc:
            raise ShuttleSerialError(
                f"Invalid JSON from device: {decoded} ({exc})"
            ) from exc
        if not isinstance(message, dict):
            # Later code calls message.get(); a bare number or list would
            # otherwise kill the reader thread with AttributeError.
            raise ShuttleSerialError(
                f"Invalid JSON from device: {decoded} (expected an object)"
            )
        self._record_sequence(message)
        return message

    def _dispatch(self, message: Dict[str, Any]) -> None:
        """Route a decoded message to its command future or event listeners."""
        mtype = message.get("type")
        if mtype == "resp":
            resp_id = message.get("id")
            if resp_id is None:
                raise ShuttleSerialError("Device response missing id field")
            with self._lock:
                future = self._pending.pop(resp_id, None)
                if future is None:
                    # No one is waiting: park it for a later send_command().
                    self._response_backlog[resp_id] = message
            if future is not None:
                future.mark_result(message)
        elif mtype == "ev":
            ev_name = message.get("ev")
            if not isinstance(ev_name, str):
                raise ShuttleSerialError("Device event missing ev field")
            with self._lock:
                listeners = list(self._event_listeners.get(ev_name, []))
            for listener in listeners:
                listener.emit(message)
        else:
            raise ShuttleSerialError(f"Received unexpected payload: {message}")

    def _log_serial(self, direction: str, payload: bytes) -> None:
        logger = getattr(self, "_logger", None)
        if logger is not None:
            logger.log(direction, payload)

    def _record_sequence(self, message: Dict[str, Any]) -> None:
        """Feed an integer ``seq`` field, if present, to the sequence tracker."""
        tracker = getattr(self, "_seq_tracker", None)
        if tracker is None:
            return
        seq_value = message.get("seq")
        if not isinstance(seq_value, int):
            return
        mtype = message.get("type", "?")
        if mtype == "resp" and "id" in message:
            source = f"response id={message['id']}"
        elif mtype == "ev" and "ev" in message:
            source = f"event {message['ev']}"
        else:
            source = mtype
        tracker.observe(seq_value, source=source)

    def _fail_all(self, exc: BaseException) -> None:
        """Fail every pending command and listener with *exc* and clear state."""
        with self._lock:
            pending = list(self._pending.values())
            self._pending.clear()
            self._response_backlog.clear()
            listeners: List[EventSubscription] = []
            for group in self._event_listeners.values():
                listeners.extend(group)
            self._event_listeners.clear()
        # Resolve outside the lock: callbacks and teardowns re-acquire it.
        for future in pending:
            future.mark_exception(exc)
        for listener in listeners:
            listener.fail(exc)