swarmit 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dotbot-firmware/doc/sphinx/conf.py +191 -0
- {swarmit-0.2.0.dist-info → swarmit-0.3.0.dist-info}/METADATA +4 -2
- swarmit-0.3.0.dist-info/RECORD +12 -0
- {swarmit-0.2.0.dist-info → swarmit-0.3.0.dist-info}/WHEEL +1 -1
- testbed/cli/main.py +293 -527
- testbed/swarmit/__init__.py +1 -0
- testbed/swarmit/adapter.py +94 -0
- testbed/swarmit/controller.py +620 -0
- testbed/swarmit/protocol.py +292 -0
- swarmit-0.2.0.dist-info/RECORD +0 -7
- {swarmit-0.2.0.dist-info → swarmit-0.3.0.dist-info}/entry_points.txt +0 -0
- {swarmit-0.2.0.dist-info → swarmit-0.3.0.dist-info}/licenses/AUTHORS +0 -0
- {swarmit-0.2.0.dist-info → swarmit-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,620 @@
|
|
1
|
+
"""Module containing the swarmit controller class."""
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
import time
|
5
|
+
from dataclasses import dataclass
|
6
|
+
from typing import Optional
|
7
|
+
|
8
|
+
import serial
|
9
|
+
from cryptography.hazmat.primitives import hashes
|
10
|
+
from dotbot.logger import LOGGER
|
11
|
+
from dotbot.protocol import Frame, Header
|
12
|
+
from dotbot.serial_interface import SerialInterfaceException, get_default_port
|
13
|
+
from rich import print
|
14
|
+
from rich.console import Console
|
15
|
+
from rich.live import Live
|
16
|
+
from rich.table import Table
|
17
|
+
from tqdm import tqdm
|
18
|
+
|
19
|
+
from testbed.swarmit.adapter import (
|
20
|
+
GatewayAdapterBase,
|
21
|
+
MQTTAdapter,
|
22
|
+
SerialAdapter,
|
23
|
+
)
|
24
|
+
from testbed.swarmit.protocol import (
|
25
|
+
PayloadMessage,
|
26
|
+
PayloadOTAChunkRequest,
|
27
|
+
PayloadOTAStartRequest,
|
28
|
+
PayloadResetRequest,
|
29
|
+
PayloadStartRequest,
|
30
|
+
PayloadStatusRequest,
|
31
|
+
PayloadStopRequest,
|
32
|
+
StatusType,
|
33
|
+
SwarmitPayloadType,
|
34
|
+
register_parsers,
|
35
|
+
)
|
36
|
+
|
37
|
+
# Number of firmware bytes carried by each OTA chunk request.
CHUNK_SIZE = 128
# Default for ``ControllerSettings.serial_port``, as reported by dotbot's
# ``get_default_port()`` helper.
SERIAL_PORT_DEFAULT = get_default_port()
|
39
|
+
|
40
|
+
|
41
|
+
@dataclass
class DataChunk:
    """One slice of a firmware image, sent in a single OTA chunk request."""

    index: int  # position of the chunk within the image (0-based)
    size: int  # number of bytes held in ``data``
    data: bytes  # raw firmware bytes of this chunk
|
48
|
+
|
49
|
+
|
50
|
+
@dataclass
class StartOtaData:
    """Outcome of an OTA start request.

    Holds the announced chunk count, the firmware digest and the IDs of the
    devices that acknowledged the request.
    """

    chunks: int = 0  # number of firmware chunks announced to the devices
    fw_hash: bytes = b""  # digest of the whole firmware image
    # Device IDs that acknowledged; each instance gets its own list.
    ids: list[str] = dataclasses.field(default_factory=list)
|
57
|
+
|
58
|
+
|
59
|
+
@dataclass
class TransferDataStatus:
    """Progress of a firmware transfer for one device."""

    # Per-chunk counter of the last transmission attempt (indexed by chunk).
    retries: list[int] = dataclasses.field(default_factory=list)
    # Indexes of the chunks the device has acknowledged.
    chunks_acked: set[int] = dataclasses.field(default_factory=set)
    # Hash comparison result reported by the device.
    hashes_match: bool = False
|
66
|
+
|
67
|
+
|
68
|
+
@dataclass
class ResetLocation:
    """Position sent to a device together with a reset request."""

    pos_x: int = 0  # x coordinate
    pos_y: int = 0  # y coordinate

    def __repr__(self):
        """Render the location compactly, e.g. ``(x=1, y=2)``."""
        return f"(x={self.pos_x}, y={self.pos_y})"
|
77
|
+
|
78
|
+
|
79
|
+
def print_status(status_data: dict[str, StatusType]) -> None:
    """Print a table with the reported status of each device.

    Args:
        status_data: mapping of device ID to the ``StatusType`` it reported.
    """
    device_count = len(status_data)
    print()
    # Singular only for exactly one device: "0 devices", "1 device", "2 devices".
    print(f"{device_count} device{'' if device_count == 1 else 's'} found")
    print()
    status_table = Table()
    status_table.add_column("Device ID", style="magenta", no_wrap=True)
    status_table.add_column("Status", style="green", justify="center")
    with Live(status_table, refresh_per_second=4) as live:
        live.update(status_table)
        for device_id, status in sorted(status_data.items()):
            # Running devices are highlighted in cyan, all others in green.
            color = (
                "[bold cyan]" if status == StatusType.Running else "[bold green]"
            )
            status_table.add_row(f"{device_id}", f"{color}{status.name}")
|
96
|
+
|
97
|
+
|
98
|
+
def print_start_status(
    stopped_data: list[str], not_started: list[str]
) -> None:
    """Print a table telling which devices started and which did not.

    Args:
        stopped_data: device IDs shown with a check mark. Presumably these are
            the devices that confirmed the start (the parameter name looks
            copied from ``print_stop_status``; kept for caller compatibility).
        not_started: device IDs shown with a cross.
    """
    print("[bold]Start status:[/]")
    table = Table()
    table.add_column("Device ID", style="magenta", no_wrap=True)
    table.add_column("Status", style="green", justify="center")
    # Successful devices first, then failures, each group sorted by ID.
    rows = [
        (dev, "[bold green]:heavy_check_mark:[/]") for dev in sorted(stopped_data)
    ]
    rows.extend((dev, "[bold red]:x:[/]") for dev in sorted(not_started))
    with Live(table, refresh_per_second=4) as live:
        live.update(table)
        for dev, mark in rows:
            table.add_row(f"{dev}", mark)
|
114
|
+
|
115
|
+
|
116
|
+
def print_stop_status(stopped_data: list[str], not_stopped: list[str]) -> None:
    """Print a table telling which devices stopped and which did not.

    Args:
        stopped_data: IDs of the devices shown with a check mark.
        not_stopped: IDs of the devices shown with a cross.
    """
    print("[bold]Stop status:[/]")
    table = Table()
    table.add_column("Device ID", style="magenta", no_wrap=True)
    table.add_column("Status", style="green", justify="center")
    # Successful devices first, then failures, each group sorted by ID.
    rows = [
        (dev, "[bold green]:heavy_check_mark:[/]") for dev in sorted(stopped_data)
    ]
    rows.extend((dev, "[bold red]:x:[/]") for dev in sorted(not_stopped))
    with Live(table, refresh_per_second=4) as live:
        live.update(table)
        for dev, mark in rows:
            table.add_row(f"{dev}", mark)
|
130
|
+
|
131
|
+
|
132
|
+
def print_transfer_status(
    status: dict[str, TransferDataStatus], start_data: StartOtaData
) -> None:
    """Print a table summarising the firmware transfer of each device.

    Args:
        status: mapping of device ID to its ``TransferDataStatus``.
        start_data: OTA start information; its ``chunks`` field is the total
            number of chunks shown as the denominator of "Chunks acked".
    """
    print()
    print("[bold]Transfer status:[/]")
    transfer_status_table = Table()
    transfer_status_table.add_column(
        "Device ID", style="magenta", no_wrap=True
    )
    transfer_status_table.add_column(
        "Chunks acked", style="green", justify="center"
    )
    transfer_status_table.add_column(
        "Hashes match", style="green", justify="center"
    )
    with Live(transfer_status_table, refresh_per_second=4) as live:
        live.update(transfer_status_table)
        # ``device_status`` avoids shadowing the ``status`` parameter.
        for device_id, device_status in sorted(status.items()):
            hashes_match = bool(device_status.hashes_match)
            color = "[bold green]" if hashes_match else "[bold red]"
            transfer_status_table.add_row(
                f"{device_id}",
                f"{len(device_status.chunks_acked)}/{start_data.chunks}",
                f"{color}{hashes_match}[/]",
            )
|
161
|
+
|
162
|
+
|
163
|
+
def wait_for_done(timeout, condition_func, poll_interval=0.01):
    """Wait until ``condition_func()`` is true or ``timeout`` seconds elapse.

    The deadline is measured on the monotonic clock, so the time spent in
    ``condition_func`` and the inaccuracy of ``time.sleep`` no longer make
    the real wait drift beyond ``timeout``.

    Args:
        timeout: maximum time to wait, in seconds. A value <= 0 returns
            ``False`` immediately without evaluating the condition.
        condition_func: zero-argument callable polled until it returns a
            truthy value.
        poll_interval: delay between two polls, in seconds.

    Returns:
        ``True`` if the condition was met before the deadline, else ``False``.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if condition_func():
            return True
        time.sleep(poll_interval)
    return False
|
171
|
+
|
172
|
+
|
173
|
+
@dataclass
class ControllerSettings:
    """Configuration used to build a ``Controller``.

    When ``edge`` is true the controller talks to the MQTT broker at
    ``mqtt_host``:``mqtt_port``; otherwise it opens ``serial_port`` at
    ``serial_baudrate``. ``devices`` restricts commands to the listed
    device IDs; an empty list addresses every device.
    """

    serial_port: str = SERIAL_PORT_DEFAULT
    serial_baudrate: int = 1_000_000
    mqtt_host: str = "argus.paris.inria.fr"
    mqtt_port: int = 8883
    edge: bool = False
    devices: list[str] = dataclasses.field(default_factory=list)
|
183
|
+
|
184
|
+
|
185
|
+
class Controller:
    """Class used to control a swarm testbed.

    Requests are sent as swarmit frames through a gateway adapter (MQTT when
    ``settings.edge`` is true, serial otherwise). Notifications coming back
    are dispatched by ``on_data_received``, which fills the per-request state
    (``status_data``, ``started_data``, ``stopped_data``, ``start_ota_data``,
    ``transfer_data``) depending on ``expected_reply``.
    """

    def __init__(self, settings: ControllerSettings):
        """Create the controller and connect it to its gateway adapter.

        Raises:
            SerialInterfaceException, serial.serialutil.SerialException: when
                the serial adapter cannot be opened (the error is printed
                before being propagated).
        """
        self.logger = LOGGER.bind(context=__name__)
        self.settings = settings
        self._interface: Optional[GatewayAdapterBase] = None
        self.status_data: dict[str, StatusType] = {}
        self.started_data: list[str] = []
        self.stopped_data: list[str] = []
        self.chunks: list[DataChunk] = []
        # SHA-256 digest of the firmware prepared by ``start_ota``; defined
        # here so the attribute always exists.
        self.fw_hash: bytes = b""
        self.start_ota_data: StartOtaData = StartOtaData()
        self.transfer_data: dict[str, TransferDataStatus] = {}
        self._known_devices: dict[str, StatusType] = {}
        # Notification type the pending request is waiting for.
        self.expected_reply: Optional[SwarmitPayloadType] = None
        register_parsers()
        if self.settings.edge is True:
            self._interface = MQTTAdapter(
                self.settings.mqtt_host, self.settings.mqtt_port
            )
        else:
            try:
                self._interface = SerialAdapter(
                    self.settings.serial_port, self.settings.serial_baudrate
                )
            except (
                SerialInterfaceException,
                serial.serialutil.SerialException,
            ) as exc:
                console = Console()
                console.print(f"[bold red]Error:[/] {exc}")
                # Without an adapter the controller cannot work: propagate the
                # real error instead of crashing below on ``None.init``.
                raise
        self._interface.init(self.on_data_received)

    @property
    def known_devices(self) -> dict[str, StatusType]:
        """Return the known devices.

        The testbed status is requested on first access and cached, as long
        as at least one device answered.
        """
        if not self._known_devices:
            self._known_devices = self.status()
        return self._known_devices

    def _devices_in_state(self, state: StatusType) -> list[str]:
        """Return known device IDs in ``state``, filtered by ``settings.devices``."""
        return [
            device_id
            for device_id, status in self.known_devices.items()
            if status == state
            and (not self.settings.devices or device_id in self.settings.devices)
        ]

    @property
    def running_devices(self) -> list[str]:
        """Return the running devices."""
        return self._devices_in_state(StatusType.Running)

    @property
    def resetting_devices(self) -> list[str]:
        """Return the resetting devices."""
        return self._devices_in_state(StatusType.Resetting)

    @property
    def ready_devices(self) -> list[str]:
        """Return the ready devices (devices in the bootloader)."""
        return self._devices_in_state(StatusType.Bootloader)

    @property
    def interface(self) -> GatewayAdapterBase:
        """Return the interface."""
        return self._interface

    def terminate(self):
        """Terminate the controller."""
        self.interface.close()

    def send_frame(self, frame: Frame):
        """Send a frame to the devices."""
        self.interface.send_data(frame.to_bytes())

    def on_data_received(self, data):
        """Dispatch one frame received from the gateway adapter.

        Args:
            data: raw bytes of the frame, decoded with ``Frame.from_bytes``.
        """
        frame = Frame().from_bytes(data)
        payload_type = frame.payload_type
        # Payload types below the swarmit request range are ignored.
        if payload_type < SwarmitPayloadType.SWARMIT_REQUEST_STATUS:
            return
        device_id = f"{frame.payload.device_id:08X}"
        expected = self.expected_reply
        if (
            payload_type == SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS
            and expected == SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS
        ):
            self.status_data.update(
                {device_id: StatusType(frame.payload.status)}
            )
        elif (
            payload_type == SwarmitPayloadType.SWARMIT_NOTIFICATION_STARTED
            and expected == SwarmitPayloadType.SWARMIT_NOTIFICATION_STARTED
        ):
            if device_id not in self.started_data:
                self.started_data.append(device_id)
        elif (
            payload_type == SwarmitPayloadType.SWARMIT_NOTIFICATION_STOPPED
            and expected == SwarmitPayloadType.SWARMIT_NOTIFICATION_STOPPED
        ):
            if device_id not in self.stopped_data:
                self.stopped_data.append(device_id)
        elif (
            payload_type == SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK
            and expected == SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK
        ):
            if device_id not in self.start_ota_data.ids:
                self.start_ota_data.ids.append(device_id)
        elif payload_type == SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_CHUNK_ACK:
            transfer_status = self.transfer_data.get(device_id)
            if transfer_status is None:
                # Ack from a device that is not part of the current transfer
                # (or arriving outside a transfer): nothing to record.
                return
            index = frame.payload.index
            if index not in transfer_status.chunks_acked:
                transfer_status.chunks_acked.add(index)
                # Only the first ack of a chunk updates the hash result, so a
                # late duplicate ack cannot overwrite a newer report.
                transfer_status.hashes_match = frame.payload.hashes_match
        elif payload_type in (
            SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_GPIO,
            SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_LOG,
        ):
            if self.settings.devices and device_id not in self.settings.devices:
                return
            logger = self.logger.bind(
                deviceid=device_id,
                notification=payload_type.name,
                timestamp=frame.payload.timestamp,
                data_size=frame.payload.count,
                data=frame.payload.data,
            )
            if payload_type == SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_GPIO:
                logger.info("GPIO event")
            elif payload_type == SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_LOG:
                logger.info("LOG event")
        elif payload_type != expected:
            self.logger.warning(
                "Unexpected payload",
                payload_type=hex(payload_type),
                # ``expected`` is None between requests; hex(None) would raise.
                expected=None if expected is None else hex(expected),
            )
        else:
            self.logger.error("Unknown payload type", payload_type=payload_type)

    def status(self):
        """Request the status of the testbed.

        Broadcasts a status request, collects the replies for one second and
        returns the mapping of device ID to reported ``StatusType``.
        """
        self.status_data = {}
        payload = PayloadStatusRequest(device_id=0)
        frame = Frame(header=Header(), payload=payload)
        self.expected_reply = SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS
        self.send_frame(frame)
        # The number of answering devices is unknown: listen for the whole window.
        wait_for_done(1, lambda: False)
        return self.status_data

    def _send_start(self, device_id: str):
        """Send a start request and wait up to 3 s for the confirmation(s).

        ``device_id`` "0" broadcasts; completion then means every ready
        device confirmed.
        """

        def is_started():
            if device_id == "0":
                return sorted(self.started_data) == sorted(self.ready_devices)
            return device_id in self.started_data

        self.expected_reply = SwarmitPayloadType.SWARMIT_NOTIFICATION_STARTED
        payload = PayloadStartRequest(device_id=int(device_id, base=16))
        self.send_frame(Frame(header=Header(), payload=payload))
        wait_for_done(3, is_started)
        self.expected_reply = None

    def start(self):
        """Start the application on the ready devices.

        Returns:
            The IDs of the devices that confirmed the start.
        """
        self.started_data = []
        ready_devices = self.ready_devices
        if not self.settings.devices:
            self._send_start("0")
        else:
            for device_id in self.settings.devices:
                if device_id not in ready_devices:
                    continue
                self._send_start(device_id)
        return self.started_data

    def _send_stop(self, device_id: str):
        """Send a stop request and wait up to 3 s for the confirmation(s).

        ``device_id`` "0" broadcasts; completion then means every running or
        resetting device confirmed.
        """
        stoppable_devices = self.running_devices + self.resetting_devices

        def is_stopped():
            if device_id == "0":
                return sorted(self.stopped_data) == sorted(stoppable_devices)
            return device_id in self.stopped_data

        self.expected_reply = SwarmitPayloadType.SWARMIT_NOTIFICATION_STOPPED
        payload = PayloadStopRequest(device_id=int(device_id, base=16))
        self.send_frame(Frame(header=Header(), payload=payload))
        wait_for_done(3, is_stopped)
        self.expected_reply = None

    def stop(self):
        """Stop the application on the running or resetting devices.

        Returns:
            The IDs of the devices that confirmed the stop.
        """
        self.stopped_data = []
        stoppable_devices = self.running_devices + self.resetting_devices
        if not self.settings.devices:
            self._send_stop("0")
        else:
            for device_id in self.settings.devices:
                if device_id not in stoppable_devices:
                    continue
                self._send_stop(device_id)
        return self.stopped_data

    def _send_reset(self, device_id: str, location: ResetLocation):
        """Send a reset request carrying ``location`` to one device."""
        payload = PayloadResetRequest(
            device_id=int(device_id, base=16),
            pos_x=location.pos_x,
            pos_y=location.pos_y,
        )
        self.send_frame(Frame(header=Header(), payload=payload))

    def reset(self, locations: dict[str, ResetLocation]):
        """Reset the selected ready devices to the given locations.

        Only devices listed in ``settings.devices`` and currently ready are
        reset. ``locations`` must contain an entry for each of them
        (``KeyError`` otherwise).
        """
        ready_devices = self.ready_devices
        for device_id in self.settings.devices:
            if device_id not in ready_devices:
                continue
            location = locations[device_id]
            print(f"Resetting device {device_id} with location {location}")
            self._send_reset(device_id, location)

    def monitor(self):
        """Monitor the testbed (blocks forever; events are logged on reception)."""
        self.logger.info("Monitoring testbed")
        while True:
            time.sleep(0.01)

    def _send_message(self, device_id, message):
        """Send the text ``message`` to one device (or "0" to broadcast)."""
        encoded = message.encode()
        payload = PayloadMessage(
            device_id=int(device_id, base=16),
            # Byte length of the encoded message: it differs from the
            # character count for non-ASCII text.
            count=len(encoded),
            message=encoded,
        )
        frame = Frame(header=Header(), payload=payload)
        self.send_frame(frame)

    def send_message(self, message):
        """Send a message to the devices."""
        running_devices = self.running_devices
        if not self.settings.devices:
            self._send_message("0", message)
        else:
            for device_id in self.settings.devices:
                if device_id not in running_devices:
                    continue
                self._send_message(device_id, message)

    def _send_start_ota(self, device_id: str, firmware: bytes):
        """Announce the OTA to one device (or "0" to broadcast) and wait up to 3 s for acks."""

        def is_start_ota_acknowledged():
            if device_id == "0":
                return sorted(self.start_ota_data.ids) == sorted(
                    self.ready_devices
                )
            return device_id in self.start_ota_data.ids

        payload = PayloadOTAStartRequest(
            device_id=int(device_id, base=16),
            fw_length=len(firmware),
            fw_chunk_count=len(self.chunks),
            fw_hash=self.fw_hash,
        )
        self.send_frame(Frame(header=Header(), payload=payload))
        wait_for_done(3, is_start_ota_acknowledged)

    def start_ota(self, firmware) -> StartOtaData:
        """Split ``firmware`` into chunks and announce the OTA to the devices.

        Args:
            firmware: the bytes-like firmware image.

        Returns:
            The ``StartOtaData`` holding the chunk count, the SHA-256 digest
            of the image and the IDs of the devices that acknowledged.
        """
        self.start_ota_data = StartOtaData()
        self.chunks = []
        digest = hashes.Hash(hashes.SHA256())
        # Every chunk is CHUNK_SIZE bytes except possibly the last one, which
        # holds the remainder (a full chunk when the size is an exact multiple).
        for chunk_idx, offset in enumerate(range(0, len(firmware), CHUNK_SIZE)):
            data = firmware[offset : offset + CHUNK_SIZE]
            digest.update(data)
            self.chunks.append(
                DataChunk(index=chunk_idx, size=len(data), data=data)
            )
        self.fw_hash = digest.finalize()
        self.expected_reply = (
            SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK
        )
        self.start_ota_data.fw_hash = self.fw_hash
        self.start_ota_data.chunks = len(self.chunks)
        if not self.settings.devices:
            print("Broadcast start ota notification...")
            self._send_start_ota("0", firmware)
        else:
            for device_id in self.settings.devices:
                print(f"Sending start ota notification to {device_id}...")
                self._send_start_ota(device_id, firmware)
        self.expected_reply = None
        return self.start_ota_data

    def send_chunk(self, chunk, device_id: str):
        """Send one firmware chunk and wait for its acknowledgement.

        The chunk is transmitted up to 3 times. After each transmission the
        method waits up to about one second for the ack before retrying.

        Args:
            chunk: the ``DataChunk`` to send.
            device_id: hexadecimal ID of the target device, or "0" to
                broadcast to every ready device.
        """
        broadcast = device_id == "0"
        # Computed once: the ack check is polled every millisecond.
        ready_devices = sorted(self.ready_devices) if broadcast else []
        # Devices whose retry counter is updated on each transmission.
        targets = ready_devices if broadcast else [device_id]

        def is_chunk_acknowledged():
            if broadcast:
                return sorted(self.transfer_data) == ready_devices and all(
                    chunk.index in status.chunks_acked
                    for status in self.transfer_data.values()
                )
            return (
                device_id in self.transfer_data
                and chunk.index in self.transfer_data[device_id].chunks_acked
            )

        for tries in range(3):
            if is_chunk_acknowledged():
                return
            payload = PayloadOTAChunkRequest(
                device_id=int(device_id, base=16),
                index=chunk.index,
                count=chunk.size,
                chunk=chunk.data,
            )
            self.send_frame(Frame(header=Header(), payload=payload))
            for target in targets:
                self.transfer_data[target].retries[chunk.index] = tries
            # Wait up to ~1 s for the ack, including after the last attempt.
            time.sleep(0.01)
            deadline = time.monotonic() + 1
            while time.monotonic() < deadline:
                if is_chunk_acknowledged():
                    return
                time.sleep(0.001)

    def transfer(self, firmware):
        """Transfer the chunks prepared by ``start_ota`` to the devices.

        Args:
            firmware: the firmware image; only its length is used, for the
                progress bar.

        Returns:
            The mapping of device ID to ``TransferDataStatus``.
        """
        data_size = len(firmware)
        progress = tqdm(
            range(0, data_size),
            unit="B",
            unit_scale=False,
            colour="green",
            ncols=100,
        )
        progress.set_description(
            f"Loading firmware ({int(data_size / 1024)}kB)"
        )
        # Resolve the targets before setting ``expected_reply``: querying the
        # ready devices may trigger a status request that overwrites it.
        if self.settings.devices:
            target_ids = self.settings.devices
        else:
            target_ids = self.ready_devices
        self.transfer_data = {
            device_id: TransferDataStatus(retries=[0] * len(self.chunks))
            for device_id in target_ids
        }
        self.expected_reply = (
            SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_CHUNK_ACK
        )
        try:
            for chunk in self.chunks:
                if not self.settings.devices:
                    self.send_chunk(chunk, "0")
                else:
                    for device_id in self.settings.devices:
                        self.send_chunk(chunk, device_id)
                progress.update(chunk.size)
        finally:
            # Always release the progress bar and the pending-reply state,
            # even if the transfer is interrupted.
            progress.close()
            self.expected_reply = None
        return self.transfer_data
|