swarmit 0.2.0__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dotbot-firmware/doc/sphinx/conf.py +191 -0
- swarmit-0.4.4.dist-info/METADATA +127 -0
- swarmit-0.4.4.dist-info/RECORD +12 -0
- {swarmit-0.2.0.dist-info → swarmit-0.4.4.dist-info}/WHEEL +1 -1
- testbed/cli/main.py +265 -529
- testbed/swarmit/__init__.py +1 -0
- testbed/swarmit/adapter.py +142 -0
- testbed/swarmit/controller.py +664 -0
- testbed/swarmit/protocol.py +231 -0
- swarmit-0.2.0.dist-info/METADATA +0 -99
- swarmit-0.2.0.dist-info/RECORD +0 -7
- {swarmit-0.2.0.dist-info → swarmit-0.4.4.dist-info}/entry_points.txt +0 -0
- {swarmit-0.2.0.dist-info → swarmit-0.4.4.dist-info}/licenses/AUTHORS +0 -0
- {swarmit-0.2.0.dist-info → swarmit-0.4.4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,664 @@
|
|
1
|
+
"""Module containing the swarmit controller class."""
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
import time
|
5
|
+
from binascii import hexlify
|
6
|
+
from dataclasses import dataclass
|
7
|
+
|
8
|
+
from cryptography.hazmat.primitives import hashes
|
9
|
+
from dotbot.logger import LOGGER
|
10
|
+
from dotbot.protocol import Packet, Payload
|
11
|
+
from dotbot.serial_interface import get_default_port
|
12
|
+
from rich import print
|
13
|
+
from rich.console import Group
|
14
|
+
from rich.live import Live
|
15
|
+
from rich.table import Table
|
16
|
+
from rich.text import Text
|
17
|
+
from tqdm import tqdm
|
18
|
+
|
19
|
+
from testbed.swarmit.adapter import (
|
20
|
+
GatewayAdapterBase,
|
21
|
+
MarilibCloudAdapter,
|
22
|
+
MarilibEdgeAdapter,
|
23
|
+
)
|
24
|
+
from testbed.swarmit.protocol import (
|
25
|
+
PayloadMessage,
|
26
|
+
PayloadOTAChunkRequest,
|
27
|
+
PayloadOTAStartRequest,
|
28
|
+
PayloadResetRequest,
|
29
|
+
PayloadStartRequest,
|
30
|
+
PayloadStopRequest,
|
31
|
+
StatusType,
|
32
|
+
SwarmitPayloadType,
|
33
|
+
register_parsers,
|
34
|
+
)
|
35
|
+
|
36
|
+
# Number of firmware bytes per OTA chunk (used by Controller.start_ota).
CHUNK_SIZE = 64
# Timeout handed to wait_for_done() / Controller._live_status() for commands.
COMMAND_TIMEOUT = 6
# Upper bound on how many times start()/stop() re-send their request.
COMMAND_MAX_ATTEMPTS = 5
# Value passed to time.sleep() between two start()/stop() attempts.
COMMAND_ATTEMPT_DELAY = 1
# Default timeout of Controller._live_status().
STATUS_TIMEOUT = 5
# Default value of ControllerSettings.ota_max_retries.
OTA_MAX_RETRIES_DEFAULT = 10
# Default value of ControllerSettings.ota_timeout; it is compared against
# time.time() differences, i.e. it is expressed in seconds.
OTA_ACK_TIMEOUT_DEFAULT = 3
# Resolved once, at import time, by dotbot's get_default_port().
SERIAL_PORT_DEFAULT = get_default_port()
# Destination address used to reach every device (all 64 bits set).
BROADCAST_ADDRESS = 0xFFFFFFFFFFFFFFFF
|
45
|
+
|
46
|
+
|
47
|
+
@dataclass
class DataChunk:
    """One slice of a firmware image, as built by ``Controller.start_ota``."""

    index: int  # position of the chunk within the firmware
    size: int  # byte count announced for the chunk
    sha: bytes  # truncated SHA-256 digest of ``data``
    data: bytes  # raw bytes of the chunk
|
55
|
+
|
56
|
+
|
57
|
+
@dataclass
class StartOtaData:
    """Bookkeeping for the start phase of an OTA update."""

    chunks: int = 0  # number of chunks the firmware was split into
    fw_hash: bytes = b""  # SHA-256 digest of the whole firmware
    # addresses of the devices that acknowledged the OTA start request
    addrs: list[str] = dataclasses.field(default_factory=list)
    retries: int = 0  # number of OTA start requests sent so far
|
65
|
+
|
66
|
+
|
67
|
+
@dataclass
class Chunk:
    """Transfer state of a single firmware chunk for one device."""

    index: str = "0"  # chunk index, kept as display text
    size: str = "0B"  # chunk size, kept as display text
    acked: int = 0  # non-zero once the device acknowledged the chunk
    retries: int = 0  # retry counter recorded while sending the chunk

    def __repr__(self):
        """Return the string form of the dictionary of fields."""
        return str(dataclasses.asdict(self))
|
78
|
+
|
79
|
+
|
80
|
+
@dataclass
class TransferDataStatus:
    """Firmware transfer progress of one device."""

    # one entry per firmware chunk
    chunks: list[Chunk] = dataclasses.field(default_factory=list)
    success: bool = False  # set by Controller.transfer once every chunk is acked
|
86
|
+
|
87
|
+
|
88
|
+
@dataclass
class ResetLocation:
    """Target position sent along with a reset request."""

    pos_x: int = 0
    pos_y: int = 0

    def __repr__(self):
        """Return the location formatted as ``(x=..., y=...)``."""
        x, y = self.pos_x, self.pos_y
        return f"(x={x}, y={y})"
|
97
|
+
|
98
|
+
|
99
|
+
def addr_to_hex(addr: int) -> str:
    """Return ``addr`` as 16 upper-case hexadecimal digits (8 bytes, big endian)."""
    raw = addr.to_bytes(8, "big")
    return raw.hex().upper()
|
102
|
+
|
103
|
+
|
104
|
+
def generate_status(status_data, devices=None, status_message="found"):
    """Build the rich renderable that summarises device statuses.

    Args:
        status_data: mapping of device address to its ``StatusType``.
        devices: optional collection of addresses to keep; when it is empty
            or ``None`` every entry of ``status_data`` is shown.
        status_message: text appended after the device count.

    Returns:
        A ``Group`` containing either a single "No device ..." message or a
        header line followed by a table with one row per device.
    """
    # ``devices`` defaults to None rather than a shared mutable list.
    selected = {
        addr: status
        for addr, status in status_data.items()
        if not devices or addr in devices
    }
    if not selected:
        return Group(Text(f"\nNo device {status_message}\n"))

    header = Text(
        f"\n{len(selected)} device{'s' if len(selected) > 1 else ''} {status_message}\n"
    )

    table = Table()
    table.add_column("Device Addr", style="magenta", no_wrap=True)
    table.add_column(
        "Status",
        style="green",
        justify="center",
        # Wide enough for the longest status name, so the column keeps the
        # same width whichever statuses are displayed.
        width=max(len(name) for name in StatusType.__members__),
    )
    for device_addr, status in sorted(selected.items()):
        color = "[bold cyan]" if status == StatusType.Running else "[bold green]"
        table.add_row(f"{device_addr}", f"{color}{status.name}")
    return Group(header, table)
|
131
|
+
|
132
|
+
|
133
|
+
def print_transfer_status(
    status: dict[str, TransferDataStatus], start_data: StartOtaData
) -> None:
    """Print a table with the number of acknowledged chunks per device.

    Args:
        status: transfer status keyed by device address.
        start_data: OTA start data; its ``chunks`` attribute is used as the
            total number of chunks in the "acked/total" column.
    """
    print()
    print("[bold]Transfer status:[/]")
    table = Table()
    table.add_column("Device Addr", style="magenta", no_wrap=True)
    table.add_column("Chunks acked", style="green", justify="center")

    with Live(table, refresh_per_second=4) as live:
        live.update(table)
        # A distinct loop name avoids shadowing the ``status`` parameter.
        for device_addr, device_status in sorted(status.items()):
            color = "[green]" if device_status.success else "[bold red]"
            acked_count = sum(
                1 for chunk in device_status.chunks if chunk.acked
            )
            table.add_row(
                f"{device_addr}",
                f"{color}{acked_count}/{start_data.chunks}",
            )
|
155
|
+
|
156
|
+
|
157
|
+
def wait_for_done(timeout, condition_func):
    """Poll ``condition_func`` until it is truthy or the timeout elapses.

    Args:
        timeout: maximum time to wait, in seconds.
        condition_func: zero-argument callable, polled about every 10 ms.

    Returns:
        True as soon as the condition holds; False when the timeout elapses
        first, or immediately when ``timeout`` is not positive.
    """
    # A monotonic deadline is used instead of subtracting the nominal sleep
    # time from a counter: that approach ignores how long the sleep and the
    # condition call really take, so the wait overruns the requested timeout.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if condition_func():
            return True
        time.sleep(0.01)
    return False
|
165
|
+
|
166
|
+
|
167
|
+
@dataclass
class ControllerSettings:
    """Configuration values consumed by ``Controller``."""

    serial_port: str = SERIAL_PORT_DEFAULT
    serial_baudrate: int = 1000000
    mqtt_host: str = "localhost"
    mqtt_port: int = 1883
    mqtt_use_tls: bool = False
    network_id: int = 1
    # Controller.__init__ only tests for "cloud" (MarilibCloudAdapter); any
    # other value, including this default, selects MarilibEdgeAdapter.
    adapter: str = "serial"
    # addresses to act on; an empty list means "all devices" (broadcast)
    devices: list[str] = dataclasses.field(default_factory=list)
    ota_max_retries: int = OTA_MAX_RETRIES_DEFAULT
    ota_timeout: float = OTA_ACK_TIMEOUT_DEFAULT
    verbose: bool = False
|
182
|
+
|
183
|
+
|
184
|
+
class Controller:
|
185
|
+
"""Class used to control a swarm testbed."""
|
186
|
+
|
187
|
+
def __init__(self, settings: ControllerSettings):
    """Store the settings, create the gateway adapter and start listening.

    The adapter is a ``MarilibCloudAdapter`` when ``settings.adapter`` is
    ``"cloud"`` and a ``MarilibEdgeAdapter`` otherwise. Received frames are
    delivered to ``on_frame_received``.
    """
    self.logger = LOGGER.bind(context=__name__)
    self.settings = settings
    self._interface: GatewayAdapterBase = None
    self.status_data: dict[str, StatusType] = {}
    self.started_data: list[str] = []
    self.stopped_data: list[str] = []
    self.chunks: list[DataChunk] = []
    self.start_ota_data: StartOtaData = StartOtaData()
    self.transfer_data: dict[str, TransferDataStatus] = {}
    self._known_devices: dict[str, StatusType] = {}
    register_parsers()
    cfg = self.settings
    if cfg.adapter != "cloud":
        self._interface = MarilibEdgeAdapter(
            cfg.serial_port,
            cfg.serial_baudrate,
            verbose=cfg.verbose,
        )
    else:
        self._interface = MarilibCloudAdapter(
            cfg.mqtt_host,
            cfg.mqtt_port,
            cfg.mqtt_use_tls,
            cfg.network_id,
            verbose=cfg.verbose,
        )
    self._interface.init(self.on_frame_received)
|
214
|
+
|
215
|
+
@property
def known_devices(self) -> dict[str, StatusType]:
    """Return the known devices.

    When none is known yet, block for ``COMMAND_TIMEOUT`` (the wait
    condition never succeeds) and then adopt ``status_data`` as the set of
    known devices.
    """
    if self._known_devices:
        return self._known_devices
    wait_for_done(COMMAND_TIMEOUT, lambda: False)
    self._known_devices = self.status_data
    return self._known_devices
|
222
|
+
|
223
|
+
@property
def running_devices(self) -> list[str]:
    """Return the selected devices whose status is Running or Programming.

    A device is selected when ``settings.devices`` is empty or contains it.
    """
    known = self.known_devices
    active = (StatusType.Running, StatusType.Programming)
    selection = self.settings.devices
    return [
        addr
        for addr, status in known.items()
        if status in active and (not selection or addr in selection)
    ]
|
239
|
+
|
240
|
+
@property
def resetting_devices(self) -> list[str]:
    """Return the selected devices whose status is Resetting.

    A device is selected when ``settings.devices`` is empty or contains it.
    """
    known = self.known_devices
    selection = self.settings.devices
    return [
        addr
        for addr, status in known.items()
        if status == StatusType.Resetting
        and (not selection or addr in selection)
    ]
|
254
|
+
|
255
|
+
@property
def ready_devices(self) -> list[str]:
    """Return the selected devices whose status is Bootloader.

    A device is selected when ``settings.devices`` is empty or contains it.
    """
    known = self.known_devices
    selection = self.settings.devices
    return [
        addr
        for addr, status in known.items()
        if status == StatusType.Bootloader
        and (not selection or addr in selection)
    ]
|
269
|
+
|
270
|
+
@property
def interface(self) -> GatewayAdapterBase:
    """Return the gateway adapter created in ``__init__``."""
    adapter = self._interface
    return adapter
|
274
|
+
|
275
|
+
def terminate(self):
    """Terminate the controller by closing its gateway adapter."""
    adapter = self.interface
    adapter.close()
|
278
|
+
|
279
|
+
def send_payload(self, destination: int, payload: Payload):
    """Forward ``payload`` to ``destination`` through the gateway adapter."""
    adapter = self.interface
    adapter.send_payload(destination, payload)
|
282
|
+
|
283
|
+
def on_frame_received(self, header, packet: Packet):
    """Dispatch a frame received from the gateway by payload type.

    Status notifications update ``status_data``, OTA start acks are recorded
    in ``start_ota_data.addrs``, OTA chunk acks mark the matching chunk in
    ``transfer_data`` and GPIO/LOG events are logged. Any other swarmit
    payload type is logged as an error.
    """
    kind = packet.payload_type
    # Payload types below the first swarmit value are ignored.
    if kind < SwarmitPayloadType.SWARMIT_REQUEST_STATUS:
        return
    device_addr = f"{header.source:08X}"

    if kind == SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS:
        self.status_data[device_addr] = StatusType(packet.payload.status)
        return

    if kind == SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK:
        # Record each acknowledging device only once.
        if device_addr not in self.start_ota_data.addrs:
            self.start_ota_data.addrs.append(device_addr)
        return

    if kind == SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_CHUNK_ACK:
        try:
            chunk_state = self.transfer_data[device_addr].chunks[
                packet.payload.index
            ]
            already_acked = bool(chunk_state.acked)
        except (IndexError, KeyError):
            # Unknown device or chunk index: nothing to mark.
            self.logger.warning(
                "Chunk index out of range",
                device_addr=device_addr,
                chunk_index=packet.payload.index,
            )
            return
        if not already_acked:
            chunk_state.acked = 1
        return

    event_types = (
        SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_GPIO,
        SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_LOG,
    )
    if kind not in event_types:
        self.logger.error(
            "Unknown payload type", payload_type=packet.payload_type
        )
        return

    # Events from devices outside the configured selection are dropped.
    selection = self.settings.devices
    if selection and device_addr not in selection:
        return
    logger = self.logger.bind(
        device_addr=device_addr,
        notification=SwarmitPayloadType(packet.payload_type).name,
        timestamp=packet.payload.timestamp,
        data_size=packet.payload.count,
        data=packet.payload.data,
    )
    if kind == SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_GPIO:
        logger.info("GPIO event")
    elif kind == SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_LOG:
        logger.info("LOG event")
|
356
|
+
|
357
|
+
def _live_status(self, devices=None, timeout=STATUS_TIMEOUT, message="found"):
    """Show a live-updating status table for ``timeout`` seconds.

    Args:
        devices: optional collection of addresses to display; empty or
            ``None`` shows every device present in ``status_data``.
        timeout: how long the live display is kept, in seconds.
        message: text shown after the device count.
    """

    def render():
        return generate_status(
            self.status_data, devices, status_message=message
        )

    # A monotonic deadline keeps the display to ``timeout`` seconds;
    # subtracting the nominal sleep time from a counter would overrun by the
    # cost of every table rebuild and sleep.
    deadline = time.monotonic() + timeout
    with Live(render(), refresh_per_second=4) as live:
        while time.monotonic() < deadline:
            live.update(render())
            time.sleep(0.01)
|
373
|
+
|
374
|
+
def status(self):
    """Display the live status of the configured devices."""
    selection = self.settings.devices
    self._live_status(selection)
|
377
|
+
|
378
|
+
def _send_start(self, device_addr: str):
    """Send a start request to ``device_addr`` (hexadecimal string)."""
    request = PayloadStartRequest()
    destination = int(device_addr, 16)
    self.send_payload(destination, request)
|
381
|
+
|
382
|
+
def start(self):
    """Start the application on the ready devices.

    The start request is sent (broadcast when no device is configured) up to
    ``COMMAND_MAX_ATTEMPTS`` times, until every ready device reports Running.
    The live status is then displayed.
    """
    targets = self.ready_devices
    for _ in range(COMMAND_MAX_ATTEMPTS):
        if all(
            self.status_data[addr] == StatusType.Running for addr in targets
        ):
            break
        if self.settings.devices:
            for device_addr in self.settings.devices:
                if device_addr in targets:
                    self._send_start(device_addr)
        else:
            self._send_start(addr_to_hex(BROADCAST_ADDRESS))
        time.sleep(COMMAND_ATTEMPT_DELAY)
    self._live_status(targets, timeout=COMMAND_TIMEOUT, message="to start")
|
402
|
+
|
403
|
+
def stop(self):
    """Stop the application on the running and resetting devices.

    The stop request is sent (broadcast when no device is configured) up to
    ``COMMAND_MAX_ATTEMPTS`` times, until every targeted device reports
    Stopping or Bootloader. The live status is then displayed.
    """
    stopped_states = [StatusType.Stopping, StatusType.Bootloader]
    stoppable_devices = self.running_devices + self.resetting_devices

    attempts = 0
    while attempts < COMMAND_MAX_ATTEMPTS and not all(
        self.status_data[addr] in stopped_states
        for addr in stoppable_devices
    ):
        if not self.settings.devices:
            self.send_payload(BROADCAST_ADDRESS, PayloadStopRequest())
        else:
            for device_addr in self.settings.devices:
                # status_data stores StatusType values directly, so they are
                # compared as-is; reading a ``.status`` attribute on them
                # raised AttributeError.
                if (
                    device_addr not in stoppable_devices
                    or self.status_data[device_addr] in stopped_states
                ):
                    continue
                self.send_payload(
                    int(device_addr, 16), PayloadStopRequest()
                )
        attempts += 1
        time.sleep(COMMAND_ATTEMPT_DELAY)
    self._live_status(
        stoppable_devices, timeout=COMMAND_TIMEOUT, message="to stop"
    )
|
431
|
+
|
432
|
+
def _send_reset(self, device_addr: int, location: ResetLocation):
    """Send a reset request carrying ``location`` to ``device_addr``."""
    request = PayloadResetRequest(pos_x=location.pos_x, pos_y=location.pos_y)
    self.send_payload(device_addr, request)
|
438
|
+
|
439
|
+
def reset(self, locations: dict[str, ResetLocation]):
    """Send a reset request to each configured device that is ready.

    Args:
        locations: reset location keyed by device address; it must contain
            an entry for every configured device that is ready.
    """
    ready = self.ready_devices
    for device_addr in self.settings.devices:
        if device_addr in ready:
            print(
                f"Resetting device {device_addr} with location {locations[device_addr]}"
            )
            self._send_reset(int(device_addr, 16), locations[device_addr])
|
449
|
+
|
450
|
+
def monitor(self):
    """Log a message, then idle forever in 10 ms sleeps.

    The method never returns on its own.
    """
    self.logger.info("Monitoring testbed")
    poll_delay = 0.01
    while True:
        time.sleep(poll_delay)
|
455
|
+
|
456
|
+
def _send_message(self, device_addr: int, message: str):
    """Send a text message to ``device_addr``.

    Args:
        device_addr: destination address.
        message: text to send; it is encoded with ``str.encode()`` defaults.
    """
    encoded = message.encode()
    # ``count`` must match the bytes actually sent: len(message) counts
    # characters and is too small whenever the text is not pure ASCII.
    payload = PayloadMessage(count=len(encoded), message=encoded)
    self.send_payload(device_addr, payload)
|
462
|
+
|
463
|
+
def send_message(self, message):
    """Send ``message`` to the devices.

    The message is broadcast when no device is configured; otherwise it is
    sent to each configured device that is currently running.
    """
    running = self.running_devices
    selection = self.settings.devices
    if selection:
        for addr in selection:
            if addr in running:
                self._send_message(int(addr, 16), message)
        return
    self._send_message(BROADCAST_ADDRESS, message)
|
473
|
+
|
474
|
+
def _send_start_ota(
    self, device_addr: str, devices_to_flash: set[str], firmware: bytes
):
    """Send the OTA start request and repeat it until it is acknowledged.

    The request is re-sent each time ``settings.ota_timeout`` elapses
    without acknowledgment, while ``start_ota_data.retries`` does not exceed
    ``settings.ota_max_retries``. For the broadcast address the
    acknowledging devices must match ``devices_to_flash`` exactly.
    """

    def is_start_ota_acknowledged():
        acked = self.start_ota_data.addrs
        if int(device_addr, 16) != BROADCAST_ADDRESS:
            return device_addr in acked
        return sorted(acked) == sorted(devices_to_flash)

    request = PayloadOTAStartRequest(
        fw_length=len(firmware),
        fw_chunk_count=len(self.chunks),
    )
    last_sent = time.time()
    resend = True
    while (
        not is_start_ota_acknowledged()
        and self.start_ota_data.retries <= self.settings.ota_max_retries
    ):
        if resend:
            self.send_payload(int(device_addr, 16), request)
            last_sent = time.time()
            self.start_ota_data.retries += 1
        time.sleep(0.001)
        # Re-send only once the ack timeout has elapsed since the last send.
        resend = time.time() - last_sent > self.settings.ota_timeout
|
501
|
+
|
502
|
+
def start_ota(self, firmware) -> dict:
    """Split the firmware into chunks and notify devices an OTA is starting.

    Args:
        firmware: firmware image as a bytes-like object.

    Returns:
        A dict with keys ``"ota"`` (the ``StartOtaData``), ``"acked"``
        (sorted addresses that acknowledged) and ``"missed"`` (sorted ready
        devices that did not acknowledge).
    """
    self.start_ota_data = StartOtaData()
    self.chunks = []
    digest = hashes.Hash(hashes.SHA256())
    # Slice by offset and take the size from the slice itself. Computing the
    # last chunk size as ``len(firmware) % CHUNK_SIZE`` gave 0 for a firmware
    # whose length is an exact multiple of CHUNK_SIZE, so its last chunk was
    # empty and the final bytes were never transferred.
    for chunk_idx, offset in enumerate(range(0, len(firmware), CHUNK_SIZE)):
        data = firmware[offset : offset + CHUNK_SIZE]
        digest.update(data)
        chunk_sha = hashes.Hash(hashes.SHA256())
        chunk_sha.update(data)
        self.chunks.append(
            DataChunk(
                index=chunk_idx,
                size=len(data),
                # the first 8 bytes should be enough
                sha=chunk_sha.finalize()[:8],
                data=data,
            )
        )
    self.start_ota_data.fw_hash = digest.finalize()
    self.start_ota_data.chunks = len(self.chunks)
    devices_to_flash = self.ready_devices
    if not self.settings.devices:
        print("Broadcast start ota notification...")
        self._send_start_ota(
            addr_to_hex(BROADCAST_ADDRESS), devices_to_flash, firmware
        )
    else:
        for addr in devices_to_flash:
            print(f"Sending start ota notification to {addr}...")
            self._send_start_ota(addr, devices_to_flash, firmware)
            time.sleep(0.2)
    return {
        "ota": self.start_ota_data,
        "acked": sorted(self.start_ota_data.addrs),
        "missed": sorted(
            set(devices_to_flash).difference(self.start_ota_data.addrs)
        ),
    }
|
553
|
+
|
554
|
+
def send_chunk(
    self,
    chunk: DataChunk,
    device_addr: str,
    devices_to_flash: set[str],
):
    """Send one firmware chunk and repeat it until it is acknowledged.

    The chunk is re-sent each time ``settings.ota_timeout`` elapses without
    acknowledgment, for at most ``settings.ota_max_retries`` + 1 sends.

    Args:
        chunk: the chunk to transfer.
        device_addr: hexadecimal destination address; with the broadcast
            address, every device in ``devices_to_flash`` must acknowledge.
        devices_to_flash: addresses of the devices being flashed.
    """

    def is_chunk_acknowledged():
        if int(device_addr, 16) == BROADCAST_ADDRESS:
            return sorted(self.transfer_data.keys()) == sorted(
                devices_to_flash
            ) and all(
                status.chunks[chunk.index].acked
                for status in self.transfer_data.values()
            )
        return (
            device_addr in self.transfer_data
            and self.transfer_data[device_addr].chunks[chunk.index].acked
        )

    payload = PayloadOTAChunkRequest(
        index=chunk.index,
        count=chunk.size,
        sha=chunk.sha,
        chunk=chunk.data,
    )
    send_time = time.time()
    send = True
    retries_count = 0
    while (
        not is_chunk_acknowledged()
        and retries_count <= self.settings.ota_max_retries
    ):
        if send is True:
            if self.settings.verbose:
                missing_acks = [
                    addr
                    for addr in devices_to_flash
                    if addr not in self.transfer_data
                    or not self.transfer_data[addr].chunks[chunk.index].acked
                ]
                # start_ota_data.chunks is already the chunk count (an int);
                # calling len() on it raised TypeError.
                print(
                    f"Transferring chunk {chunk.index}/{self.start_ota_data.chunks} to {device_addr} "
                    f"- {retries_count} retries "
                    f"- {len(missing_acks)} missing acks: {', '.join(missing_acks) if missing_acks else 'none'}"
                )
            self.send_payload(int(device_addr, 16), payload)
            # Record the retry count on every device concerned by this send.
            if int(device_addr, 16) == BROADCAST_ADDRESS:
                for addr in devices_to_flash:
                    self.transfer_data[addr].chunks[
                        chunk.index
                    ].retries = retries_count
            else:
                self.transfer_data[device_addr].chunks[
                    chunk.index
                ].retries = retries_count
            send_time = time.time()
            retries_count += 1
        time.sleep(0.001)
        # Re-send only once the ack timeout has elapsed since the last send.
        send = time.time() - send_time > self.settings.ota_timeout
|
620
|
+
|
621
|
+
def transfer(self, firmware, devices) -> dict[str, TransferDataStatus]:
    """Send every prepared chunk to ``devices`` and report the outcome.

    Chunks are broadcast when no device is configured in the settings,
    otherwise they are sent to each device in turn. A progress bar is shown
    unless verbose mode is on.

    Returns:
        The per-device transfer status; ``success`` is set for devices that
        acknowledged all their chunks.
    """
    total_bytes = len(firmware)
    show_progress = not self.settings.verbose
    if show_progress:
        progress = tqdm(
            range(0, total_bytes),
            unit="B",
            unit_scale=False,
            colour="green",
            ncols=100,
        )
        progress.set_description(
            f"Loading firmware ({int(total_bytes / 1024)}kB)"
        )

    # Start from a clean state with one un-acked entry per chunk and device.
    self.transfer_data = {}
    for device_addr in devices:
        self.transfer_data[device_addr] = TransferDataStatus()
        self.transfer_data[device_addr].chunks = [
            Chunk(index=f"{idx:03d}", size=f"{item.size:03d}B")
            for idx, item in enumerate(self.chunks)
        ]

    broadcast = not self.settings.devices
    for chunk in self.chunks:
        if broadcast:
            self.send_chunk(chunk, addr_to_hex(BROADCAST_ADDRESS), devices)
        else:
            for addr in devices:
                self.send_chunk(chunk, addr, devices)
        if show_progress:
            progress.update(chunk.size)
    if show_progress:
        progress.close()

    for device in devices:
        device_data = self.transfer_data.get(device)
        if not device_data:
            continue
        device_data.success = all(item.acked for item in device_data.chunks)
        self.transfer_data[device] = device_data
    return self.transfer_data
|