swarmit 0.3.0__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- swarmit-0.4.5.dist-info/METADATA +128 -0
- swarmit-0.4.5.dist-info/RECORD +12 -0
- testbed/cli/main.py +140 -170
- testbed/swarmit/__init__.py +1 -1
- testbed/swarmit/adapter.py +112 -64
- testbed/swarmit/controller.py +350 -306
- testbed/swarmit/protocol.py +38 -99
- swarmit-0.3.0.dist-info/METADATA +0 -101
- swarmit-0.3.0.dist-info/RECORD +0 -12
- {swarmit-0.3.0.dist-info → swarmit-0.4.5.dist-info}/WHEEL +0 -0
- {swarmit-0.3.0.dist-info → swarmit-0.4.5.dist-info}/entry_points.txt +0 -0
- {swarmit-0.3.0.dist-info → swarmit-0.4.5.dist-info}/licenses/AUTHORS +0 -0
- {swarmit-0.3.0.dist-info → swarmit-0.4.5.dist-info}/licenses/LICENSE +0 -0
testbed/swarmit/controller.py
CHANGED
@@ -2,24 +2,24 @@
|
|
2
2
|
|
3
3
|
import dataclasses
|
4
4
|
import time
|
5
|
+
from binascii import hexlify
|
5
6
|
from dataclasses import dataclass
|
6
|
-
from typing import Optional
|
7
7
|
|
8
|
-
import serial
|
9
8
|
from cryptography.hazmat.primitives import hashes
|
10
9
|
from dotbot.logger import LOGGER
|
11
|
-
from dotbot.protocol import
|
12
|
-
from dotbot.serial_interface import
|
10
|
+
from dotbot.protocol import Packet, Payload
|
11
|
+
from dotbot.serial_interface import get_default_port
|
13
12
|
from rich import print
|
14
|
-
from rich.console import
|
13
|
+
from rich.console import Group
|
15
14
|
from rich.live import Live
|
16
15
|
from rich.table import Table
|
16
|
+
from rich.text import Text
|
17
17
|
from tqdm import tqdm
|
18
18
|
|
19
19
|
from testbed.swarmit.adapter import (
|
20
20
|
GatewayAdapterBase,
|
21
|
-
|
22
|
-
|
21
|
+
MarilibCloudAdapter,
|
22
|
+
MarilibEdgeAdapter,
|
23
23
|
)
|
24
24
|
from testbed.swarmit.protocol import (
|
25
25
|
PayloadMessage,
|
@@ -27,15 +27,21 @@ from testbed.swarmit.protocol import (
|
|
27
27
|
PayloadOTAStartRequest,
|
28
28
|
PayloadResetRequest,
|
29
29
|
PayloadStartRequest,
|
30
|
-
PayloadStatusRequest,
|
31
30
|
PayloadStopRequest,
|
32
31
|
StatusType,
|
33
32
|
SwarmitPayloadType,
|
34
33
|
register_parsers,
|
35
34
|
)
|
36
35
|
|
37
|
-
CHUNK_SIZE =
|
36
|
+
CHUNK_SIZE = 64
|
37
|
+
COMMAND_TIMEOUT = 6
|
38
|
+
COMMAND_MAX_ATTEMPTS = 5
|
39
|
+
COMMAND_ATTEMPT_DELAY = 1
|
40
|
+
STATUS_TIMEOUT = 5
|
41
|
+
OTA_MAX_RETRIES_DEFAULT = 10
|
42
|
+
OTA_ACK_TIMEOUT_DEFAULT = 3
|
38
43
|
SERIAL_PORT_DEFAULT = get_default_port()
|
44
|
+
BROADCAST_ADDRESS = 0xFFFFFFFFFFFFFFFF
|
39
45
|
|
40
46
|
|
41
47
|
@dataclass
|
@@ -44,6 +50,7 @@ class DataChunk:
|
|
44
50
|
|
45
51
|
index: int
|
46
52
|
size: int
|
53
|
+
sha: bytes
|
47
54
|
data: bytes
|
48
55
|
|
49
56
|
|
@@ -53,16 +60,29 @@ class StartOtaData:
|
|
53
60
|
|
54
61
|
chunks: int = 0
|
55
62
|
fw_hash: bytes = b""
|
56
|
-
|
63
|
+
addrs: list[str] = dataclasses.field(default_factory=lambda: [])
|
64
|
+
retries: int = 0
|
65
|
+
|
66
|
+
|
67
|
+
@dataclass
|
68
|
+
class Chunk:
|
69
|
+
"""Class that holds chunk status."""
|
70
|
+
|
71
|
+
index: str = "0"
|
72
|
+
size: str = "0B"
|
73
|
+
acked: int = 0
|
74
|
+
retries: int = 0
|
75
|
+
|
76
|
+
def __repr__(self):
|
77
|
+
return f"{dataclasses.asdict(self)}"
|
57
78
|
|
58
79
|
|
59
80
|
@dataclass
|
60
81
|
class TransferDataStatus:
|
61
82
|
"""Class that holds transfer data status for a single device."""
|
62
83
|
|
63
|
-
|
64
|
-
|
65
|
-
hashes_match: bool = False
|
84
|
+
chunks: list[Chunk] = dataclasses.field(default_factory=lambda: [])
|
85
|
+
success: bool = False
|
66
86
|
|
67
87
|
|
68
88
|
@dataclass
|
@@ -76,57 +96,38 @@ class ResetLocation:
|
|
76
96
|
return f"(x={self.pos_x}, y={self.pos_y})"
|
77
97
|
|
78
98
|
|
79
|
-
def
|
80
|
-
"""
|
81
|
-
|
82
|
-
print(
|
83
|
-
f"{len(status_data)} device{'s' if len(status_data) > 1 else ''} found"
|
84
|
-
)
|
85
|
-
print()
|
86
|
-
status_table = Table()
|
87
|
-
status_table.add_column("Device ID", style="magenta", no_wrap=True)
|
88
|
-
status_table.add_column("Status", style="green", justify="center")
|
89
|
-
with Live(status_table, refresh_per_second=4) as live:
|
90
|
-
live.update(status_table)
|
91
|
-
for device_id, status in sorted(status_data.items()):
|
92
|
-
status_table.add_row(
|
93
|
-
f"{device_id}",
|
94
|
-
f'{"[bold cyan]" if status == StatusType.Running else "[bold green]"}{status.name}',
|
95
|
-
)
|
99
|
+
def addr_to_hex(addr: int) -> str:
|
100
|
+
"""Convert an address to its hexadecimal representation."""
|
101
|
+
return hexlify(addr.to_bytes(8, "big")).decode().upper()
|
96
102
|
|
97
103
|
|
98
|
-
def
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
status_table.add_row(
|
126
|
-
f"{device_id}", "[bold green]:heavy_check_mark:[/]"
|
127
|
-
)
|
128
|
-
for device_id in sorted(not_stopped):
|
129
|
-
status_table.add_row(f"{device_id}", "[bold red]:x:[/]")
|
104
|
+
def generate_status(status_data, devices=[], status_message="found"):
|
105
|
+
data = {
|
106
|
+
key: device
|
107
|
+
for key, device in status_data.items()
|
108
|
+
if (devices and key in devices) or (not devices)
|
109
|
+
}
|
110
|
+
if not data:
|
111
|
+
return Group(Text(f"\nNo device {status_message}\n"))
|
112
|
+
|
113
|
+
header = Text(
|
114
|
+
f"\n{len(data)} device{'s' if len(data) > 1 else ''} {status_message}\n"
|
115
|
+
)
|
116
|
+
|
117
|
+
table = Table()
|
118
|
+
table.add_column("Device Addr", style="magenta", no_wrap=True)
|
119
|
+
table.add_column(
|
120
|
+
"Status",
|
121
|
+
style="green",
|
122
|
+
justify="center",
|
123
|
+
width=max([len(m) for m in StatusType.__members__]),
|
124
|
+
)
|
125
|
+
for device_addr, status in sorted(data.items()):
|
126
|
+
table.add_row(
|
127
|
+
f"{device_addr}",
|
128
|
+
f"{'[bold cyan]' if status == StatusType.Running else '[bold green]'}{status.name}",
|
129
|
+
)
|
130
|
+
return Group(header, table)
|
130
131
|
|
131
132
|
|
132
133
|
def print_transfer_status(
|
@@ -137,26 +138,19 @@ def print_transfer_status(
|
|
137
138
|
print("[bold]Transfer status:[/]")
|
138
139
|
transfer_status_table = Table()
|
139
140
|
transfer_status_table.add_column(
|
140
|
-
"Device
|
141
|
+
"Device Addr", style="magenta", no_wrap=True
|
141
142
|
)
|
142
143
|
transfer_status_table.add_column(
|
143
144
|
"Chunks acked", style="green", justify="center"
|
144
145
|
)
|
145
|
-
|
146
|
-
"Hashes match", style="green", justify="center"
|
147
|
-
)
|
146
|
+
|
148
147
|
with Live(transfer_status_table, refresh_per_second=4) as live:
|
149
148
|
live.update(transfer_status_table)
|
150
|
-
for
|
151
|
-
|
152
|
-
("[bold green]", "[/]")
|
153
|
-
if bool(status.hashes_match) is True
|
154
|
-
else ("[bold red]", "[/]")
|
155
|
-
)
|
149
|
+
for device_addr, status in sorted(status.items()):
|
150
|
+
chunks_col_color = "[green]" if status.success else "[bold red]"
|
156
151
|
transfer_status_table.add_row(
|
157
|
-
f"{
|
158
|
-
f"{len(status.
|
159
|
-
f"{start_marker}{bool(status.hashes_match)}{stop_marker}",
|
152
|
+
f"{device_addr}",
|
153
|
+
f"{chunks_col_color}{len([chunk for chunk in status.chunks if bool(chunk.acked)])}/{start_data.chunks}",
|
160
154
|
)
|
161
155
|
|
162
156
|
|
@@ -176,10 +170,15 @@ class ControllerSettings:
|
|
176
170
|
|
177
171
|
serial_port: str = SERIAL_PORT_DEFAULT
|
178
172
|
serial_baudrate: int = 1000000
|
179
|
-
mqtt_host: str = "
|
180
|
-
mqtt_port: int =
|
181
|
-
|
173
|
+
mqtt_host: str = "localhost"
|
174
|
+
mqtt_port: int = 1883
|
175
|
+
mqtt_use_tls: bool = False
|
176
|
+
network_id: int = 1
|
177
|
+
adapter: str = "serial" # or "mqtt", "marilib-edge", "marilib-cloud"
|
182
178
|
devices: list[str] = dataclasses.field(default_factory=lambda: [])
|
179
|
+
ota_max_retries: int = OTA_MAX_RETRIES_DEFAULT
|
180
|
+
ota_timeout: float = OTA_ACK_TIMEOUT_DEFAULT
|
181
|
+
verbose: bool = False
|
183
182
|
|
184
183
|
|
185
184
|
class Controller:
|
@@ -196,43 +195,44 @@ class Controller:
|
|
196
195
|
self.start_ota_data: StartOtaData = StartOtaData()
|
197
196
|
self.transfer_data: dict[str, TransferDataStatus] = {}
|
198
197
|
self._known_devices: dict[str, StatusType] = {}
|
199
|
-
self.expected_reply: Optional[SwarmitPayloadType] = None
|
200
198
|
register_parsers()
|
201
|
-
if self.settings.
|
202
|
-
self._interface =
|
203
|
-
self.settings.mqtt_host,
|
199
|
+
if self.settings.adapter == "cloud":
|
200
|
+
self._interface = MarilibCloudAdapter(
|
201
|
+
self.settings.mqtt_host,
|
202
|
+
self.settings.mqtt_port,
|
203
|
+
self.settings.mqtt_use_tls,
|
204
|
+
self.settings.network_id,
|
205
|
+
verbose=self.settings.verbose,
|
204
206
|
)
|
205
207
|
else:
|
206
|
-
|
207
|
-
self.
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
serial.serialutil.SerialException,
|
213
|
-
) as exc:
|
214
|
-
console = Console()
|
215
|
-
console.print(f"[bold red]Error:[/] {exc}")
|
216
|
-
self._interface.init(self.on_data_received)
|
208
|
+
self._interface = MarilibEdgeAdapter(
|
209
|
+
self.settings.serial_port,
|
210
|
+
self.settings.serial_baudrate,
|
211
|
+
verbose=self.settings.verbose,
|
212
|
+
)
|
213
|
+
self._interface.init(self.on_frame_received)
|
217
214
|
|
218
215
|
@property
|
219
216
|
def known_devices(self) -> dict[str, StatusType]:
|
220
217
|
"""Return the known devices."""
|
221
218
|
if not self._known_devices:
|
222
|
-
|
219
|
+
wait_for_done(COMMAND_TIMEOUT, lambda: False)
|
220
|
+
self._known_devices = self.status_data
|
223
221
|
return self._known_devices
|
224
222
|
|
225
223
|
@property
|
226
224
|
def running_devices(self) -> list[str]:
|
227
225
|
"""Return the running devices."""
|
228
226
|
return [
|
229
|
-
|
230
|
-
for
|
227
|
+
addr
|
228
|
+
for addr, status in self.known_devices.items()
|
231
229
|
if (
|
232
|
-
|
230
|
+
(
|
231
|
+
status == StatusType.Running
|
232
|
+
or status == StatusType.Programming
|
233
|
+
)
|
233
234
|
and (
|
234
|
-
not self.settings.devices
|
235
|
-
or device_id in self.settings.devices
|
235
|
+
not self.settings.devices or addr in self.settings.devices
|
236
236
|
)
|
237
237
|
)
|
238
238
|
]
|
@@ -241,13 +241,13 @@ class Controller:
|
|
241
241
|
def resetting_devices(self) -> list[str]:
|
242
242
|
"""Return the resetting devices."""
|
243
243
|
return [
|
244
|
-
|
245
|
-
for
|
244
|
+
device_addr
|
245
|
+
for device_addr, status in self.known_devices.items()
|
246
246
|
if (
|
247
247
|
status == StatusType.Resetting
|
248
248
|
and (
|
249
249
|
not self.settings.devices
|
250
|
-
or
|
250
|
+
or device_addr in self.settings.devices
|
251
251
|
)
|
252
252
|
)
|
253
253
|
]
|
@@ -256,13 +256,13 @@ class Controller:
|
|
256
256
|
def ready_devices(self) -> list[str]:
|
257
257
|
"""Return the ready devices."""
|
258
258
|
return [
|
259
|
-
|
260
|
-
for
|
259
|
+
device_addr
|
260
|
+
for device_addr, status in self.known_devices.items()
|
261
261
|
if (
|
262
262
|
status == StatusType.Bootloader
|
263
263
|
and (
|
264
264
|
not self.settings.devices
|
265
|
-
or
|
265
|
+
or device_addr in self.settings.devices
|
266
266
|
)
|
267
267
|
)
|
268
268
|
]
|
@@ -276,181 +276,176 @@ class Controller:
|
|
276
276
|
"""Terminate the controller."""
|
277
277
|
self.interface.close()
|
278
278
|
|
279
|
-
def
|
279
|
+
def send_payload(self, destination: int, payload: Payload):
|
280
280
|
"""Send a frame to the devices."""
|
281
|
-
self.interface.
|
282
|
-
|
283
|
-
def
|
284
|
-
|
285
|
-
if
|
281
|
+
self.interface.send_payload(destination, payload)
|
282
|
+
|
283
|
+
def on_frame_received(self, header, packet: Packet):
|
284
|
+
"""Handle the received frame."""
|
285
|
+
# if self.settings.verbose:
|
286
|
+
# print()
|
287
|
+
# print(Frame(header, packet))
|
288
|
+
if packet.payload_type < SwarmitPayloadType.SWARMIT_REQUEST_STATUS:
|
286
289
|
return
|
287
|
-
|
290
|
+
device_addr = f"{header.source:08X}"
|
288
291
|
if (
|
289
|
-
|
290
|
-
== SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS
|
291
|
-
and self.expected_reply
|
292
|
+
packet.payload_type
|
292
293
|
== SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS
|
293
294
|
):
|
294
295
|
self.status_data.update(
|
295
|
-
{
|
296
|
+
{device_addr: StatusType(packet.payload.status)}
|
296
297
|
)
|
297
298
|
elif (
|
298
|
-
|
299
|
-
== SwarmitPayloadType.SWARMIT_NOTIFICATION_STARTED
|
300
|
-
and self.expected_reply
|
301
|
-
== SwarmitPayloadType.SWARMIT_NOTIFICATION_STARTED
|
302
|
-
):
|
303
|
-
if device_id not in self.started_data:
|
304
|
-
self.started_data.append(device_id)
|
305
|
-
elif (
|
306
|
-
frame.payload_type
|
307
|
-
== SwarmitPayloadType.SWARMIT_NOTIFICATION_STOPPED
|
308
|
-
and self.expected_reply
|
309
|
-
== SwarmitPayloadType.SWARMIT_NOTIFICATION_STOPPED
|
310
|
-
):
|
311
|
-
if device_id not in self.stopped_data:
|
312
|
-
self.stopped_data.append(device_id)
|
313
|
-
elif (
|
314
|
-
frame.payload_type
|
315
|
-
== SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK
|
316
|
-
and self.expected_reply
|
299
|
+
packet.payload_type
|
317
300
|
== SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK
|
318
301
|
):
|
319
|
-
if
|
320
|
-
|
302
|
+
if device_addr in self.start_ota_data.addrs:
|
303
|
+
return
|
304
|
+
self.start_ota_data.addrs.append(device_addr)
|
321
305
|
elif (
|
322
|
-
|
306
|
+
packet.payload_type
|
323
307
|
== SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_CHUNK_ACK
|
324
308
|
):
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
frame.payload.index
|
309
|
+
try:
|
310
|
+
acked = bool(
|
311
|
+
self.transfer_data[device_addr]
|
312
|
+
.chunks[packet.payload.index]
|
313
|
+
.acked
|
331
314
|
)
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
315
|
+
except (IndexError, KeyError):
|
316
|
+
self.logger.warning(
|
317
|
+
"Chunk index out of range",
|
318
|
+
device_addr=device_addr,
|
319
|
+
chunk_index=packet.payload.index,
|
320
|
+
)
|
321
|
+
return
|
322
|
+
if acked is False:
|
323
|
+
self.transfer_data[device_addr].chunks[
|
324
|
+
packet.payload.index
|
325
|
+
].acked = 1
|
326
|
+
elif packet.payload_type in [
|
336
327
|
SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_GPIO,
|
337
328
|
SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_LOG,
|
338
329
|
]:
|
339
330
|
if (
|
340
331
|
self.settings.devices
|
341
|
-
and
|
332
|
+
and device_addr not in self.settings.devices
|
342
333
|
):
|
343
334
|
return
|
344
335
|
logger = self.logger.bind(
|
345
|
-
|
346
|
-
notification=
|
347
|
-
timestamp=
|
348
|
-
data_size=
|
349
|
-
data=
|
336
|
+
device_addr=device_addr,
|
337
|
+
notification=SwarmitPayloadType(packet.payload_type).name,
|
338
|
+
timestamp=packet.payload.timestamp,
|
339
|
+
data_size=packet.payload.count,
|
340
|
+
data=packet.payload.data,
|
350
341
|
)
|
351
342
|
if (
|
352
|
-
|
343
|
+
packet.payload_type
|
353
344
|
== SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_GPIO
|
354
345
|
):
|
355
346
|
logger.info("GPIO event")
|
356
347
|
elif (
|
357
|
-
|
348
|
+
packet.payload_type
|
358
349
|
== SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_LOG
|
359
350
|
):
|
360
351
|
logger.info("LOG event")
|
361
|
-
elif frame.payload_type != self.expected_reply:
|
362
|
-
self.logger.warning(
|
363
|
-
"Unexpected payload",
|
364
|
-
payload_type=hex(frame.payload_type),
|
365
|
-
expected=hex(self.expected_reply),
|
366
|
-
)
|
367
352
|
else:
|
368
353
|
self.logger.error(
|
369
|
-
"Unknown payload type", payload_type=
|
354
|
+
"Unknown payload type", payload_type=packet.payload_type
|
370
355
|
)
|
371
356
|
|
357
|
+
def _live_status(
|
358
|
+
self, devices=[], timeout=STATUS_TIMEOUT, message="found"
|
359
|
+
):
|
360
|
+
"""Request the live status of the testbed."""
|
361
|
+
with Live(
|
362
|
+
generate_status(self.status_data, devices, status_message=message),
|
363
|
+
refresh_per_second=4,
|
364
|
+
) as live:
|
365
|
+
while timeout > 0:
|
366
|
+
live.update(
|
367
|
+
generate_status(
|
368
|
+
self.status_data, devices, status_message=message
|
369
|
+
)
|
370
|
+
)
|
371
|
+
timeout -= 0.01
|
372
|
+
time.sleep(0.01)
|
373
|
+
|
372
374
|
def status(self):
|
373
375
|
"""Request the status of the testbed."""
|
374
|
-
self.
|
375
|
-
payload = PayloadStatusRequest(device_id=0)
|
376
|
-
frame = Frame(header=Header(), payload=payload)
|
377
|
-
self.expected_reply = SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS
|
378
|
-
self.send_frame(frame)
|
379
|
-
wait_for_done(1, lambda: False)
|
380
|
-
return self.status_data
|
381
|
-
|
382
|
-
def _send_start(self, device_id: str):
|
383
|
-
def is_started():
|
384
|
-
if device_id == "0":
|
385
|
-
return sorted(self.started_data) == sorted(self.ready_devices)
|
386
|
-
else:
|
387
|
-
return device_id in self.started_data
|
376
|
+
self._live_status(self.settings.devices)
|
388
377
|
|
389
|
-
|
390
|
-
payload = PayloadStartRequest(
|
391
|
-
self.
|
392
|
-
wait_for_done(3, is_started)
|
393
|
-
self.expected_reply = None
|
378
|
+
def _send_start(self, device_addr: str):
|
379
|
+
payload = PayloadStartRequest()
|
380
|
+
self.send_payload(int(device_addr, 16), payload)
|
394
381
|
|
395
382
|
def start(self):
|
396
383
|
"""Start the application."""
|
397
|
-
self.started_data = []
|
398
384
|
ready_devices = self.ready_devices
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
for
|
403
|
-
|
404
|
-
|
405
|
-
self._send_start(
|
406
|
-
return self.started_data
|
407
|
-
|
408
|
-
def _send_stop(self, device_id: str):
|
409
|
-
stoppable_devices = self.running_devices + self.resetting_devices
|
410
|
-
|
411
|
-
def is_stopped():
|
412
|
-
if device_id == "0":
|
413
|
-
return sorted(self.stopped_data) == sorted(stoppable_devices)
|
385
|
+
attempts = 0
|
386
|
+
while attempts < COMMAND_MAX_ATTEMPTS and not all(
|
387
|
+
self.status_data[addr] == StatusType.Running
|
388
|
+
for addr in ready_devices
|
389
|
+
):
|
390
|
+
if not self.settings.devices:
|
391
|
+
self._send_start(addr_to_hex(BROADCAST_ADDRESS))
|
414
392
|
else:
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
self.
|
393
|
+
for device_addr in self.settings.devices:
|
394
|
+
if device_addr not in ready_devices:
|
395
|
+
continue
|
396
|
+
self._send_start(device_addr)
|
397
|
+
attempts += 1
|
398
|
+
time.sleep(COMMAND_ATTEMPT_DELAY)
|
399
|
+
self._live_status(
|
400
|
+
ready_devices, timeout=COMMAND_TIMEOUT, message="to start"
|
401
|
+
)
|
422
402
|
|
423
403
|
def stop(self):
|
424
404
|
"""Stop the application."""
|
425
|
-
self.stopped_data = []
|
426
405
|
stoppable_devices = self.running_devices + self.resetting_devices
|
427
|
-
if not self.settings.devices:
|
428
|
-
self._send_stop("0")
|
429
|
-
else:
|
430
|
-
for device_id in self.settings.devices:
|
431
|
-
if device_id not in stoppable_devices:
|
432
|
-
continue
|
433
|
-
self._send_stop(device_id)
|
434
|
-
return self.stopped_data
|
435
406
|
|
436
|
-
|
407
|
+
attempts = 0
|
408
|
+
while attempts < COMMAND_MAX_ATTEMPTS and not all(
|
409
|
+
self.status_data[addr]
|
410
|
+
in [StatusType.Stopping, StatusType.Bootloader]
|
411
|
+
for addr in stoppable_devices
|
412
|
+
):
|
413
|
+
if not self.settings.devices:
|
414
|
+
self.send_payload(BROADCAST_ADDRESS, PayloadStopRequest())
|
415
|
+
else:
|
416
|
+
for device_addr in self.settings.devices:
|
417
|
+
if (
|
418
|
+
device_addr not in stoppable_devices
|
419
|
+
or self.status_data[device_addr].status
|
420
|
+
in [StatusType.Stopping, StatusType.Bootloader]
|
421
|
+
):
|
422
|
+
continue
|
423
|
+
self.send_payload(
|
424
|
+
int(device_addr, 16), PayloadStopRequest()
|
425
|
+
)
|
426
|
+
attempts += 1
|
427
|
+
time.sleep(COMMAND_ATTEMPT_DELAY)
|
428
|
+
self._live_status(
|
429
|
+
stoppable_devices, timeout=COMMAND_TIMEOUT, message="to stop"
|
430
|
+
)
|
431
|
+
|
432
|
+
def _send_reset(self, device_addr: int, location: ResetLocation):
|
437
433
|
payload = PayloadResetRequest(
|
438
|
-
device_id=int(device_id, base=16),
|
439
434
|
pos_x=location.pos_x,
|
440
435
|
pos_y=location.pos_y,
|
441
436
|
)
|
442
|
-
self.
|
437
|
+
self.send_payload(device_addr, payload)
|
443
438
|
|
444
439
|
def reset(self, locations: dict[str, ResetLocation]):
|
445
440
|
"""Reset the application."""
|
446
441
|
ready_devices = self.ready_devices
|
447
|
-
for
|
448
|
-
if
|
442
|
+
for device_addr in self.settings.devices:
|
443
|
+
if device_addr not in ready_devices:
|
449
444
|
continue
|
450
445
|
print(
|
451
|
-
f"Resetting device {
|
446
|
+
f"Resetting device {device_addr} with location {locations[device_addr]}"
|
452
447
|
)
|
453
|
-
self._send_reset(
|
448
|
+
self._send_reset(int(device_addr, 16), locations[device_addr])
|
454
449
|
|
455
450
|
def monitor(self):
|
456
451
|
"""Monitor the testbed."""
|
@@ -458,44 +453,51 @@ class Controller:
|
|
458
453
|
while True:
|
459
454
|
time.sleep(0.01)
|
460
455
|
|
461
|
-
def _send_message(self,
|
456
|
+
def _send_message(self, device_addr: int, message: str):
|
462
457
|
payload = PayloadMessage(
|
463
|
-
device_id=int(device_id, base=16),
|
464
458
|
count=len(message),
|
465
459
|
message=message.encode(),
|
466
460
|
)
|
467
|
-
|
468
|
-
self.send_frame(frame)
|
461
|
+
self.send_payload(device_addr, payload)
|
469
462
|
|
470
463
|
def send_message(self, message):
|
471
464
|
"""Send a message to the devices."""
|
472
465
|
running_devices = self.running_devices
|
473
466
|
if not self.settings.devices:
|
474
|
-
self._send_message(
|
467
|
+
self._send_message(BROADCAST_ADDRESS, message)
|
475
468
|
else:
|
476
|
-
for
|
477
|
-
if
|
469
|
+
for addr in self.settings.devices:
|
470
|
+
if addr not in running_devices:
|
478
471
|
continue
|
479
|
-
self._send_message(
|
480
|
-
|
481
|
-
def _send_start_ota(self, device_id: str, firmware: bytes):
|
472
|
+
self._send_message(int(addr, 16), message)
|
482
473
|
|
474
|
+
def _send_start_ota(
|
475
|
+
self, device_addr: str, devices_to_flash: set[str], firmware: bytes
|
476
|
+
):
|
483
477
|
def is_start_ota_acknowledged():
|
484
|
-
if
|
485
|
-
return sorted(self.start_ota_data.
|
486
|
-
|
478
|
+
if int(device_addr, 16) == BROADCAST_ADDRESS:
|
479
|
+
return sorted(self.start_ota_data.addrs) == sorted(
|
480
|
+
devices_to_flash
|
487
481
|
)
|
488
482
|
else:
|
489
|
-
return
|
483
|
+
return device_addr in self.start_ota_data.addrs
|
490
484
|
|
491
485
|
payload = PayloadOTAStartRequest(
|
492
|
-
device_id=int(device_id, base=16),
|
493
486
|
fw_length=len(firmware),
|
494
487
|
fw_chunk_count=len(self.chunks),
|
495
|
-
fw_hash=self.fw_hash,
|
496
488
|
)
|
497
|
-
|
498
|
-
|
489
|
+
send_time = time.time()
|
490
|
+
send = True
|
491
|
+
while (
|
492
|
+
not is_start_ota_acknowledged()
|
493
|
+
and self.start_ota_data.retries <= self.settings.ota_max_retries
|
494
|
+
):
|
495
|
+
if send is True:
|
496
|
+
self.send_payload(int(device_addr, 16), payload)
|
497
|
+
send_time = time.time()
|
498
|
+
self.start_ota_data.retries += 1
|
499
|
+
time.sleep(0.001)
|
500
|
+
send = time.time() - send_time > self.settings.ota_timeout
|
499
501
|
|
500
502
|
def start_ota(self, firmware) -> StartOtaData:
|
501
503
|
"""Start the OTA process."""
|
@@ -514,107 +516,149 @@ class Controller:
|
|
514
516
|
chunk_idx * CHUNK_SIZE : chunk_idx * CHUNK_SIZE + chunk_size
|
515
517
|
]
|
516
518
|
digest.update(data)
|
519
|
+
chunk_sha = hashes.Hash(hashes.SHA256())
|
520
|
+
chunk_sha.update(data)
|
517
521
|
self.chunks.append(
|
518
522
|
DataChunk(
|
519
523
|
index=chunk_idx,
|
520
524
|
size=chunk_size,
|
525
|
+
sha=chunk_sha.finalize()[
|
526
|
+
:8
|
527
|
+
], # the first 8 bytes should be enough
|
521
528
|
data=data,
|
522
529
|
)
|
523
530
|
)
|
524
|
-
self.fw_hash = digest.finalize()
|
525
|
-
self.expected_reply = (
|
526
|
-
SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK
|
527
|
-
)
|
528
|
-
self.start_ota_data.fw_hash = self.fw_hash
|
531
|
+
self.start_ota_data.fw_hash = digest.finalize()
|
529
532
|
self.start_ota_data.chunks = len(self.chunks)
|
533
|
+
devices_to_flash = self.ready_devices
|
530
534
|
if not self.settings.devices:
|
531
535
|
print("Broadcast start ota notification...")
|
532
|
-
self._send_start_ota(
|
536
|
+
self._send_start_ota(
|
537
|
+
addr_to_hex(BROADCAST_ADDRESS), devices_to_flash, firmware
|
538
|
+
)
|
533
539
|
else:
|
534
|
-
for
|
535
|
-
print(f"Sending start ota notification to {
|
536
|
-
self._send_start_ota(
|
537
|
-
|
538
|
-
return
|
539
|
-
|
540
|
-
|
541
|
-
|
540
|
+
for addr in devices_to_flash:
|
541
|
+
print(f"Sending start ota notification to {addr}...")
|
542
|
+
self._send_start_ota(addr, devices_to_flash, firmware)
|
543
|
+
time.sleep(0.2)
|
544
|
+
return {
|
545
|
+
"ota": self.start_ota_data,
|
546
|
+
"acked": sorted(self.start_ota_data.addrs),
|
547
|
+
"missed": sorted(
|
548
|
+
set(devices_to_flash).difference(
|
549
|
+
set(self.start_ota_data.addrs)
|
550
|
+
)
|
551
|
+
),
|
552
|
+
}
|
553
|
+
|
554
|
+
def send_chunk(
|
555
|
+
self,
|
556
|
+
chunk: DataChunk,
|
557
|
+
device_addr: str,
|
558
|
+
devices_to_flash: set[str],
|
559
|
+
):
|
542
560
|
def is_chunk_acknowledged():
|
543
|
-
if
|
561
|
+
if int(device_addr, 16) == BROADCAST_ADDRESS:
|
544
562
|
return sorted(self.transfer_data.keys()) == sorted(
|
545
|
-
|
563
|
+
devices_to_flash
|
546
564
|
) and all(
|
547
565
|
[
|
548
|
-
chunk.index
|
566
|
+
status.chunks[chunk.index].acked
|
549
567
|
for status in self.transfer_data.values()
|
550
568
|
]
|
551
569
|
)
|
552
570
|
else:
|
553
571
|
return (
|
554
|
-
|
555
|
-
and
|
556
|
-
|
572
|
+
device_addr in self.transfer_data.keys()
|
573
|
+
and self.transfer_data[device_addr]
|
574
|
+
.chunks[chunk.index]
|
575
|
+
.acked
|
557
576
|
)
|
558
577
|
|
578
|
+
payload = PayloadOTAChunkRequest(
|
579
|
+
index=chunk.index,
|
580
|
+
count=chunk.size,
|
581
|
+
sha=chunk.sha,
|
582
|
+
chunk=chunk.data,
|
583
|
+
)
|
559
584
|
send_time = time.time()
|
560
585
|
send = True
|
561
|
-
|
562
|
-
while
|
563
|
-
|
564
|
-
|
586
|
+
retries_count = 0
|
587
|
+
while (
|
588
|
+
not is_chunk_acknowledged()
|
589
|
+
and retries_count <= self.settings.ota_max_retries
|
590
|
+
):
|
565
591
|
if send is True:
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
592
|
+
if self.settings.verbose:
|
593
|
+
missing_acks = [
|
594
|
+
addr
|
595
|
+
for addr in devices_to_flash
|
596
|
+
if addr not in self.transfer_data
|
597
|
+
or not self.transfer_data[addr]
|
598
|
+
.chunks[chunk.index]
|
599
|
+
.acked
|
600
|
+
]
|
601
|
+
print(
|
602
|
+
f"Transferring chunk {chunk.index}/{len(self.start_ota_data.chunks)} to {device_addr} "
|
603
|
+
f"- {retries_count} retries "
|
604
|
+
f"- {len(missing_acks)} missing acks: {', '.join(missing_acks) if missing_acks else 'none'}"
|
605
|
+
)
|
606
|
+
self.send_payload(int(device_addr, 16), payload)
|
607
|
+
if int(device_addr, 16) == BROADCAST_ADDRESS:
|
608
|
+
for addr in devices_to_flash:
|
609
|
+
self.transfer_data[addr].chunks[
|
576
610
|
chunk.index
|
577
|
-
] =
|
611
|
+
].retries = retries_count
|
578
612
|
else:
|
579
|
-
self.transfer_data[
|
580
|
-
|
581
|
-
|
613
|
+
self.transfer_data[device_addr].chunks[
|
614
|
+
chunk.index
|
615
|
+
].retries = retries_count
|
582
616
|
send_time = time.time()
|
617
|
+
retries_count += 1
|
583
618
|
time.sleep(0.001)
|
584
|
-
send = time.time() - send_time >
|
619
|
+
send = time.time() - send_time > self.settings.ota_timeout
|
585
620
|
|
586
|
-
def transfer(self, firmware):
|
621
|
+
def transfer(self, firmware, devices) -> dict[str, TransferDataStatus]:
|
587
622
|
"""Transfer the firmware to the devices."""
|
588
623
|
data_size = len(firmware)
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
)
|
624
|
+
use_progress_bar = not self.settings.verbose
|
625
|
+
if use_progress_bar:
|
626
|
+
progress = tqdm(
|
627
|
+
range(0, data_size),
|
628
|
+
unit="B",
|
629
|
+
unit_scale=False,
|
630
|
+
colour="green",
|
631
|
+
ncols=100,
|
632
|
+
)
|
633
|
+
progress.set_description(
|
634
|
+
f"Loading firmware ({int(data_size / 1024)}kB)"
|
635
|
+
)
|
602
636
|
self.transfer_data = {}
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
self.
|
607
|
-
|
608
|
-
|
609
|
-
self.transfer_data[device_id] = TransferDataStatus()
|
610
|
-
self.transfer_data[device_id].retries = [0] * len(self.chunks)
|
637
|
+
for device_addr in devices:
|
638
|
+
self.transfer_data[device_addr] = TransferDataStatus()
|
639
|
+
self.transfer_data[device_addr].chunks = [
|
640
|
+
Chunk(index=f"{i:03d}", size=f"{self.chunks[i].size:03d}B")
|
641
|
+
for i in range(len(self.chunks))
|
642
|
+
]
|
611
643
|
for chunk in self.chunks:
|
612
644
|
if not self.settings.devices:
|
613
|
-
self.send_chunk(
|
645
|
+
self.send_chunk(
|
646
|
+
chunk,
|
647
|
+
addr_to_hex(BROADCAST_ADDRESS),
|
648
|
+
devices,
|
649
|
+
)
|
614
650
|
else:
|
615
|
-
for
|
616
|
-
self.send_chunk(chunk,
|
617
|
-
|
618
|
-
|
619
|
-
|
651
|
+
for addr in devices:
|
652
|
+
self.send_chunk(chunk, addr, devices)
|
653
|
+
if use_progress_bar:
|
654
|
+
progress.update(chunk.size)
|
655
|
+
if use_progress_bar:
|
656
|
+
progress.close()
|
657
|
+
for device in devices:
|
658
|
+
device_data = self.transfer_data.get(device)
|
659
|
+
if device_data:
|
660
|
+
device_data.success = all(
|
661
|
+
chunk.acked for chunk in device_data.chunks
|
662
|
+
)
|
663
|
+
self.transfer_data[device] = device_data
|
620
664
|
return self.transfer_data
|