OpenShock-AutoFlasher 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,11 @@
1
+ """
2
+ OpenShock Auto-Flasher
3
+ Automatically flashes OpenShock hubs when plugged in
4
+ Background colors indicate status:
5
+ Blue=Waiting, Yellow=Flashing, Green=Done, Red=Error
6
+ """
7
+
8
+ from openshock_autoflasher.flasher import AutoFlasher
9
+ from openshock_autoflasher.constants import BASE_URL, BAUD_RATE
10
+
11
+ __all__ = ["AutoFlasher", "BASE_URL", "BAUD_RATE"]
@@ -0,0 +1,9 @@
1
+ """
2
+ Entry point for running openshock_autoflasher as a module
3
+ Usage: python -m openshock_autoflasher
4
+ """
5
+
6
+ from openshock_autoflasher.cli import main
7
+
8
+ if __name__ == "__main__":
9
+ main()
@@ -0,0 +1,131 @@
1
+ """
2
+ Command-line interface for OpenShock Auto-Flasher
3
+ """
4
+
5
+ import argparse
6
+ import signal
7
+ import sys
8
+ from typing import List
9
+
10
+ import requests
11
+ from rich.style import Style
12
+
13
+ from .constants import BASE_URL
14
+ from .styles import console
15
+ from .flasher import AutoFlasher
16
+
17
+
18
def fetch_boards_for_help(channel: str = "stable") -> List[str]:
    """Return the board names published for ``channel``, for the help text.

    Two requests are made against ``BASE_URL``: the first resolves the
    channel's current version, the second downloads that version's
    ``boards.txt``.  This is best effort only: it feeds the ``--help``
    epilog, so a network failure must never prevent the CLI from starting.

    Args:
        channel: Firmware channel whose board list should be shown.

    Returns:
        The board names, one per entry, with blank lines removed.  If either
        request fails, a single-element list holding a human-readable
        placeholder message is returned instead.
    """
    try:
        version_response = requests.get(
            f"{BASE_URL}/version-{channel}.txt", timeout=5
        )
        version_response.raise_for_status()
        version = version_response.text.strip()

        boards_response = requests.get(
            f"{BASE_URL}/{version}/boards.txt", timeout=5
        )
        boards_response.raise_for_status()
    except requests.RequestException:
        # Covers connection errors, timeouts and non-2xx HTTP statuses.
        return ["(Unable to fetch boards list - check network connection)"]

    # Skip blank lines so a trailing newline or an empty file does not show
    # up as an empty "board" entry in the help output.
    stripped_lines = (line.strip() for line in boards_response.text.splitlines())
    return [name for name in stripped_lines if name]
33
+
34
+
35
def create_argument_parser(channel: str = "stable") -> argparse.ArgumentParser:
    """Build the CLI parser, listing the boards of ``channel`` in its epilog.

    Args:
        channel: Firmware channel whose boards are looked up (via
            ``fetch_boards_for_help``) and printed below the option list.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    board_lines = "\n ".join(fetch_boards_for_help(channel))
    epilog_text = f"Available boards ({channel} channel):\n {board_lines}"

    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog_text,
        description="OpenShock Auto-Flasher",
    )

    # (flags, keyword arguments) pairs, registered in display order.
    option_specs = [
        (
            ("--channel", "-c"),
            {
                "choices": ["stable", "beta", "develop"],
                "default": "stable",
                "help": "Firmware channel (default: stable)",
            },
        ),
        (
            ("--board", "-b"),
            {"required": True, "help": "Board type (required)"},
        ),
        (
            ("--erase", "-e"),
            {"action": "store_true", "help": "Erase flash before flashing"},
        ),
        (
            ("--no-auto", "-n"),
            {
                "action": "store_true",
                "help": "Disable auto-flash (just detect devices)",
            },
        ),
        (
            ("--post-flash", "-p"),
            {
                "action": "append",
                "help": (
                    "Serial command to send to device after flashing "
                    "(can be specified multiple times, executed in order)"
                ),
            },
        ),
    ]
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)

    return cli
82
+
83
+
84
_KNOWN_CHANNELS = ("stable", "beta", "develop")


def _detect_channel(argv: List[str], default: str = "stable") -> str:
    """Find the firmware channel in ``argv`` without running the full parser.

    Recognises every spelling argparse accepts for the option:
    ``--channel X``, ``--channel=X``, ``-c X``, ``-cX`` and ``-c=X``.
    Unknown values are ignored.  As with argparse, the last valid
    occurrence wins.
    """
    channel = default
    for index, arg in enumerate(argv):
        if arg in ("--channel", "-c"):
            if index + 1 >= len(argv):
                continue
            candidate = argv[index + 1]
        elif arg.startswith("--channel="):
            candidate = arg[len("--channel="):]
        elif arg.startswith("-c") and not arg.startswith("--"):
            candidate = arg[2:]
            if candidate.startswith("="):
                candidate = candidate[1:]
        else:
            continue

        if candidate in _KNOWN_CHANNELS:
            channel = candidate
    return channel


def main() -> None:
    """Main entry point for the application.

    Installs a SIGINT handler for a clean Ctrl+C exit, builds the argument
    parser (whose help text depends on the selected channel), parses the
    command line and hands control to ``AutoFlasher.run()``.
    """

    # Set up signal handler for clean exit on Ctrl+C
    def signal_handler(sig: int, frame: object) -> None:
        console.print("\n")
        console.print(
            "Exiting...",
            style=Style(color="white"),
            markup=False,
            highlight=False,
        )
        sys.exit(0)  # Clean exit with proper cleanup

    signal.signal(signal.SIGINT, signal_handler)

    # The help epilog lists the boards of the selected channel, so the
    # channel has to be known before the real parser can be built.
    channel = _detect_channel(sys.argv[1:])

    # Create parser with dynamic help
    parser = create_argument_parser(channel)

    # Show help if no arguments provided
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    args = parser.parse_args()

    flasher = AutoFlasher(
        channel=args.channel,
        board=args.board,
        erase_flash=args.erase,
        auto_flash=not args.no_auto,
        post_flash_commands=args.post_flash or [],
    )

    flasher.run()
128
+
129
+
130
+ if __name__ == "__main__":
131
+ main()
@@ -0,0 +1,19 @@
1
+ """
2
+ Configuration constants for OpenShock Auto-Flasher
3
+ """
4
+
5
+ # Network and firmware settings
6
+ REQUEST_TIMEOUT: int = 30 # seconds
7
+ BASE_URL: str = "https://firmware.openshock.org"
8
+
9
+ # ESP32 flash settings
10
+ BAUD_RATE: str = "460800"
11
+ FLASH_MODE: str = "dio"
12
+ FLASH_FREQ: str = "80m"
13
+ FLASH_ADDRESS: str = "0x0000"
14
+
15
+ # Polling settings
16
+ INITIAL_POLL_INTERVAL: float = 0.5 # seconds
17
+ MAX_POLL_INTERVAL: float = 1.0 # seconds
18
+ POLL_BACKOFF_THRESHOLD: int = 10 # checks before increasing interval
19
+ DEVICE_INIT_DELAY: int = 1 # seconds to wait after device detection
@@ -0,0 +1,479 @@
1
+ """
2
+ Core flashing logic for OpenShock Auto-Flasher
3
+ """
4
+
5
+ import hashlib
6
+ import tempfile
7
+ import textwrap
8
+ import time
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ from pathlib import Path
11
+ from typing import Optional, List, Set
12
+
13
+ import esptool
14
+ import requests
15
+ import serial
16
+ import serial.tools.list_ports
17
+ from rich.style import Style
18
+
19
+ from .constants import (
20
+ REQUEST_TIMEOUT,
21
+ BASE_URL,
22
+ BAUD_RATE,
23
+ FLASH_MODE,
24
+ FLASH_FREQ,
25
+ FLASH_ADDRESS,
26
+ INITIAL_POLL_INTERVAL,
27
+ MAX_POLL_INTERVAL,
28
+ POLL_BACKOFF_THRESHOLD,
29
+ DEVICE_INIT_DELAY,
30
+ )
31
+ from .styles import StateColors, console
32
+
33
+
34
class AutoFlasher:
    """Main auto-flashing controller for OpenShock devices.

    The controller resolves the firmware version for a channel, checks that
    the requested board exists, then polls the serial ports.  Each newly
    appearing port is flashed with the downloaded (and SHA-256 verified)
    firmware, optionally followed by user-supplied serial commands.  Every
    log line is printed in a style that reflects the current state
    (waiting / flashing / done / error).
    """

    def __init__(
        self,
        channel: str = "stable",
        board: Optional[str] = None,
        erase_flash: bool = False,
        auto_flash: bool = True,
        post_flash_commands: Optional[List[str]] = None,
    ) -> None:
        """Store the configuration and initialise runtime state.

        Args:
            channel: Firmware release channel to fetch the version from.
            board: Board name to flash; validated against the published
                board list in ``run()``.
            erase_flash: Erase the whole flash before writing firmware.
            auto_flash: Flash detected devices; when ``False`` devices are
                only reported.
            post_flash_commands: Serial commands sent, in order, after a
                successful flash.
        """
        self.channel: str = channel
        self.board: Optional[str] = board
        self.erase_flash: bool = erase_flash
        self.auto_flash: bool = auto_flash
        self.post_flash_commands: List[str] = post_flash_commands or []
        self.base_url: str = BASE_URL
        self.known_ports: Set[str] = set()
        self.state: str = "waiting"
        self.current_style: Style = StateColors.WAITING
        # Cache version to avoid refetching
        self.version_cache: Optional[str] = None
        # Cache boards list
        self.boards_cache: Optional[List[str]] = None

    def get_style(self) -> "Style":
        """Return the style for the current state (waiting style if unknown)."""
        styles = {
            "waiting": StateColors.WAITING,
            "flashing": StateColors.FLASHING,
            "done": StateColors.DONE,
            "error": StateColors.ERROR,
        }
        return styles.get(self.state, StateColors.WAITING)

    def set_state(self, state: str) -> None:
        """Change state and update the style used by subsequent log lines."""
        self.state = state
        self.current_style = self.get_style()

    def log(self, message: str) -> None:
        """Print a timestamped message padded with the current style.

        Messages longer than the usable console width are wrapped on word
        boundaries; every printed line is right-padded so the background
        colour covers the same width.
        """
        timestamp = time.strftime("%H:%M:%S")
        # Leave a two-column margin.  Clamp to 1 because textwrap.wrap()
        # raises ValueError for a non-positive width (very narrow consoles).
        max_len = max(1, console.width - 2)
        text = f"[{timestamp}] {message}"

        if len(text) > max_len:
            # Wrap the text, preserving words
            lines = textwrap.wrap(
                text,
                width=max_len,
                break_long_words=False,
                break_on_hyphens=False,
            )
        else:
            lines = [text]

        for line in lines:
            console.print(
                line.ljust(max_len),
                style=self.current_style,
                markup=False,
                highlight=False,
            )

    def fetch_version(self) -> str:
        """Fetch latest version for the selected channel (cached).

        Raises:
            requests.RequestException: If the request fails or returns a
                non-success status.
        """
        if self.version_cache:
            return self.version_cache
        url = f"{self.base_url}/version-{self.channel}.txt"
        self.log(f"Fetching version from {self.channel} channel...")
        response = requests.get(url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        self.version_cache = response.text.strip()
        self.log(f"Latest {self.channel} version: {self.version_cache}")
        return self.version_cache

    def fetch_boards(self, version: str) -> List[str]:
        """Fetch available boards for a version (cached).

        The cache holds the first successful result and is not keyed by
        ``version``.

        Raises:
            requests.RequestException: If the request fails or returns a
                non-success status.
        """
        if self.boards_cache:
            return self.boards_cache
        url = f"{self.base_url}/{version}/boards.txt"
        self.log("Fetching available boards...")
        response = requests.get(url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        # Drop blank lines so they are never treated as board names.
        stripped_lines = (line.strip() for line in response.text.splitlines())
        self.boards_cache = [name for name in stripped_lines if name]
        self.log(f"Available boards: {', '.join(self.boards_cache)}")
        return self.boards_cache

    @staticmethod
    def _find_expected_hash(
        hash_text: str, filename: str = "firmware.bin"
    ) -> Optional[str]:
        """Return the digest listed for ``filename`` in a checksum listing.

        Each line is expected to look like ``<digest> <name>``.  The name
        may carry the ``*`` binary-mode marker written by ``sha256sum``
        and/or a leading ``./``.  Returns ``None`` if no line matches.
        """
        for line in hash_text.splitlines():
            parts = line.split()
            if len(parts) < 2:
                continue
            listed_name = " ".join(parts[1:])
            if listed_name.startswith("*"):
                listed_name = listed_name[1:]
            if listed_name.startswith("./"):
                listed_name = listed_name[2:]
            if listed_name == filename:
                return parts[0]
        return None

    def download_firmware(self, version: str, board: str) -> bytes:
        """Download the firmware binary and verify its SHA-256 digest.

        The binary and its checksum file are fetched in parallel.

        Returns:
            The verified firmware image.

        Raises:
            requests.RequestException: If either download fails.
            ValueError: If the checksum file has no entry for
                ``firmware.bin`` or the digest does not match.
        """
        self.log(f"Downloading firmware for {board}...")

        firmware_url = f"{self.base_url}/{version}/{board}/firmware.bin"
        hash_url = f"{self.base_url}/{version}/{board}/hashes.sha256.txt"

        def download_firmware_data() -> bytes:
            response = requests.get(
                firmware_url, stream=True, timeout=REQUEST_TIMEOUT
            )
            response.raise_for_status()
            return response.content

        def download_hash_data() -> str:
            response = requests.get(hash_url, timeout=REQUEST_TIMEOUT)
            response.raise_for_status()
            return response.text

        # Use thread pool to download in parallel
        with ThreadPoolExecutor(max_workers=2) as executor:
            firmware_future = executor.submit(download_firmware_data)
            hash_future = executor.submit(download_hash_data)

            firmware_data = firmware_future.result()
            hash_text = hash_future.result()

        expected_hash = self._find_expected_hash(hash_text)
        if not expected_hash:
            raise ValueError("Could not find hash for firmware.bin")

        # Verify hash (case-insensitive comparison; hexdigest() is lowercase)
        calculated_hash = hashlib.sha256(firmware_data).hexdigest()
        if calculated_hash != expected_hash.lower():
            raise ValueError(
                f"Hash mismatch! Expected {expected_hash}, "
                f"got {calculated_hash}"
            )

        size_bytes = len(firmware_data)
        self.log(f"✓ Firmware downloaded and verified ({size_bytes} bytes)")
        return firmware_data

    def execute_post_flash_commands(self, port: str) -> None:
        """Send the configured post-flash commands over the serial port.

        Failures are logged as warnings and never propagate: post-flash
        commands are treated as optional extras after a successful flash.
        """
        try:
            self.log("")
            self.log("=" * 60)
            cmd_total = len(self.post_flash_commands)
            self.log(f"Executing {cmd_total} post-flash command(s)...")
            self.log("=" * 60)

            # Give device time to reboot after flash
            time.sleep(2)

            # The context manager closes the port even if a command fails.
            with serial.Serial(port, 115200, timeout=2) as ser:
                time.sleep(0.5)  # Allow connection to stabilize

                # Clear any buffered data
                ser.reset_input_buffer()
                ser.reset_output_buffer()

                for i, cmd in enumerate(self.post_flash_commands, 1):
                    self.log(f"[{i}/{cmd_total}] Sending: {cmd}")

                    # Send command with newline
                    ser.write((cmd + "\n").encode("utf-8"))
                    ser.flush()

                    # Wait a bit for command to execute
                    time.sleep(0.5)

                    # Read any response
                    if ser.in_waiting > 0:
                        response = (
                            ser.read(ser.in_waiting)
                            .decode("utf-8", errors="ignore")
                            .strip()
                        )
                        if response:
                            self.log(f"Response: {response}")

            self.log("✓ Post-flash commands completed")
            self.log("=" * 60)

        except Exception as e:
            self.log(f"⚠ Warning: Post-flash command execution failed: {e}")
            self.log("Continuing anyway...")

    @staticmethod
    def _run_esptool(args: List[str], action: str) -> None:
        """Run ``esptool.main(args)``, turning a failing exit into an error.

        ``esptool.main`` may finish by raising ``SystemExit``; exit codes
        ``0`` and ``None`` both mean success and are swallowed.

        Raises:
            RuntimeError: If esptool exits with any other code.  ``action``
                names the step in the message.
        """
        try:
            esptool.main(args)
        except SystemExit as e:
            if e.code not in (0, None):
                raise RuntimeError(
                    f"{action} failed with exit code {e.code}"
                ) from e

    def flash_device(self, port: str, version: str, board: str) -> None:
        """Download, write and verify firmware on the device at ``port``.

        Steps: download + hash check, optional full erase, write, verify,
        optional post-flash commands.  The state becomes ``"done"`` on
        success and ``"error"`` on failure.

        Raises:
            Exception: Any download, erase, flash or verification failure
                is logged and re-raised.
        """
        temp_firmware: Optional[Path] = None
        try:
            self.set_state("flashing")
            self.log("=" * 60)
            self.log(f"Starting flash process for {board}")
            self.log(f"Port: {port}")
            self.log(f"Version: {version}")
            self.log("=" * 60)

            # Download firmware
            firmware_data = self.download_firmware(version, board)

            # esptool reads the image from disk, so persist it to a
            # temporary file; delete=False keeps it after the handle closes.
            with tempfile.NamedTemporaryFile(
                mode="wb",
                suffix=".bin",
                prefix="OpenShock_Firmware_",
                delete=False,
            ) as temp_file:
                temp_firmware = Path(temp_file.name)
                temp_file.write(firmware_data)

            connection_args = ["--port", port, "--baud", BAUD_RATE]

            if self.erase_flash:
                self.log("Erasing flash...")
                self._run_esptool([*connection_args, "erase-flash"], "Erase")
                self.log("✓ Erase complete")

            self.log("Flashing firmware...")
            self._run_esptool(
                [
                    *connection_args,
                    "--chip",
                    "auto",
                    "write-flash",
                    "--flash-mode",
                    FLASH_MODE,
                    "--flash-freq",
                    FLASH_FREQ,
                    "--flash-size",
                    "detect",
                    FLASH_ADDRESS,
                    str(temp_firmware),
                ],
                "Flash",
            )

            self.log("Verifying flash...")
            self._run_esptool(
                [
                    *connection_args,
                    "--chip",
                    "auto",
                    "verify-flash",
                    FLASH_ADDRESS,
                    str(temp_firmware),
                ],
                "Verification",
            )
            self.log("✓ Verification complete!")

            # Execute post-flash commands if any
            if self.post_flash_commands:
                self.execute_post_flash_commands(port)

            self.set_state("done")
            self.log("✓ Flashing complete!")
            self.log("=" * 60)
            self.log("SUCCESS! Device flashed successfully")
            self.log("=" * 60)

        except Exception as e:
            self.set_state("error")
            self.log(f"✗ Error during flashing: {e}")
            raise
        finally:
            # Runs on success, failure and interruption (Ctrl+C) alike, so
            # the temporary image never outlives the flash attempt.
            if temp_firmware is not None:
                try:
                    temp_firmware.unlink(missing_ok=True)
                except OSError as cleanup_error:
                    self.log(
                        "⚠ Warning: Could not remove temporary firmware "
                        f"file: {cleanup_error}"
                    )

    def detect_new_port(self) -> Optional[List[str]]:
        """Return serial ports that appeared since the previous call.

        ``known_ports`` is refreshed on every successful scan, so a device
        that is unplugged and plugged back in is reported again.

        Returns:
            The new port names in sorted order, or ``None`` when nothing
            new appeared or the port scan failed.
        """
        try:
            current_ports = {
                p.device for p in serial.tools.list_ports.comports()
            }
        except Exception as e:
            # Log port detection errors instead of silently ignoring
            self.log(f"⚠ Warning: Port detection error: {e}")
            return None

        new_ports = current_ports - self.known_ports
        self.known_ports = current_ports
        return sorted(new_ports) if new_ports else None

    def run(self) -> None:
        """Main run loop.

        Validates the configuration against the firmware server, then polls
        for new serial ports until interrupted.  Returns early (in the
        error state) if the firmware metadata cannot be fetched or the
        board is unknown.  An exception raised while flashing ends the
        loop after being logged with a traceback.
        """
        self.set_state("waiting")

        # Print header with background
        self.log("OpenShock Auto-Flasher")
        self.log("=" * 60)
        self.log(f"Channel: {self.channel}")
        self.log(f"Erase flash: {self.erase_flash}")
        self.log(f"Auto-flash: {self.auto_flash}")
        self.log("=" * 60)
        self.log("")

        # Fetch version and boards
        try:
            version = self.fetch_version()
            boards = self.fetch_boards(version)

            # Early validation: check board exists before waiting
            if self.board not in boards:
                self.set_state("error")
                self.log(
                    f"Error: Board '{self.board}' not found "
                    "in available boards"
                )
                self.log(f"Available boards: {', '.join(boards)}")
                return

        except Exception as e:
            self.set_state("error")
            self.log(f"Error fetching firmware info: {e}")
            return

        # Ports present at start-up are treated as already known, so only
        # devices plugged in from now on are flashed.
        self.known_ports = {
            p.device for p in serial.tools.list_ports.comports()
        }

        self.set_state("waiting")
        self.log("Waiting for device to be plugged in...")
        self.log("(Press Ctrl+C to exit)")

        try:
            # Adaptive polling: start fast, back off gradually when idle
            poll_interval = INITIAL_POLL_INTERVAL
            consecutive_checks = 0

            while True:
                new_ports = self.detect_new_port()

                if new_ports:
                    # Process all newly detected ports
                    for new_port in new_ports:
                        self.log(f"✓ Device detected on {new_port}")
                        # Reset to fast polling after device detected
                        poll_interval = INITIAL_POLL_INTERVAL
                        consecutive_checks = 0

                        if self.auto_flash:
                            # Give device time to initialize
                            time.sleep(DEVICE_INIT_DELAY)
                            self.flash_device(new_port, version, self.board)

                    if self.auto_flash:
                        self.log("")
                        self.set_state("waiting")
                        self.log("Waiting for next device...")
                        self.log("(Press Ctrl+C to exit)")
                        self.log("")
                    else:
                        self.log("Auto-flash disabled. Skipping...")
                else:
                    # Gradually increase polling interval if no activity
                    consecutive_checks += 1
                    if consecutive_checks > POLL_BACKOFF_THRESHOLD:
                        poll_interval = min(
                            MAX_POLL_INTERVAL, poll_interval + 0.1
                        )

                time.sleep(poll_interval)

        except KeyboardInterrupt:
            console.print("\n")
            self.log("Exiting...")

        except Exception as e:
            self.set_state("error")
            self.log(f"Fatal error: {e}")
            import traceback

            traceback.print_exc()
@@ -0,0 +1,2 @@
1
+ # Marker file for PEP 561.
2
+ # This package supports type hints.
@@ -0,0 +1,18 @@
1
+ """
2
+ Terminal styling and colors for OpenShock Auto-Flasher
3
+ """
4
+
5
+ from rich.console import Console
6
+ from rich.style import Style
7
+
8
+
9
+ # Color styles for different states
10
class StateColors:
    """Namespace of the rich styles used for each flasher state."""

    # Idle, waiting for a device to be plugged in.
    WAITING: Style = Style(color="white", bgcolor="blue", bold=True)
    # A flash operation is in progress.
    FLASHING: Style = Style(color="black", bgcolor="yellow", bold=True)
    # The last flash finished successfully.
    DONE: Style = Style(color="black", bgcolor="green", bold=True)
    # Something went wrong.
    ERROR: Style = Style(color="white", bgcolor="red", bold=True)
15
+
16
+
17
+ # Global console instance for rich output
18
+ console: Console = Console(force_terminal=True, color_system="truecolor")