swarmit 0.3.0__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,24 +2,24 @@
2
2
 
3
3
  import dataclasses
4
4
  import time
5
+ from binascii import hexlify
5
6
  from dataclasses import dataclass
6
- from typing import Optional
7
7
 
8
- import serial
9
8
  from cryptography.hazmat.primitives import hashes
10
9
  from dotbot.logger import LOGGER
11
- from dotbot.protocol import Frame, Header
12
- from dotbot.serial_interface import SerialInterfaceException, get_default_port
10
+ from dotbot.protocol import Packet, Payload
11
+ from dotbot.serial_interface import get_default_port
13
12
  from rich import print
14
- from rich.console import Console
13
+ from rich.console import Group
15
14
  from rich.live import Live
16
15
  from rich.table import Table
16
+ from rich.text import Text
17
17
  from tqdm import tqdm
18
18
 
19
19
  from testbed.swarmit.adapter import (
20
20
  GatewayAdapterBase,
21
- MQTTAdapter,
22
- SerialAdapter,
21
+ MarilibCloudAdapter,
22
+ MarilibEdgeAdapter,
23
23
  )
24
24
  from testbed.swarmit.protocol import (
25
25
  PayloadMessage,
@@ -27,15 +27,21 @@ from testbed.swarmit.protocol import (
27
27
  PayloadOTAStartRequest,
28
28
  PayloadResetRequest,
29
29
  PayloadStartRequest,
30
- PayloadStatusRequest,
31
30
  PayloadStopRequest,
32
31
  StatusType,
33
32
  SwarmitPayloadType,
34
33
  register_parsers,
35
34
  )
36
35
 
37
- CHUNK_SIZE = 128
36
+ CHUNK_SIZE = 64
37
+ COMMAND_TIMEOUT = 6
38
+ COMMAND_MAX_ATTEMPTS = 5
39
+ COMMAND_ATTEMPT_DELAY = 1
40
+ STATUS_TIMEOUT = 5
41
+ OTA_MAX_RETRIES_DEFAULT = 10
42
+ OTA_ACK_TIMEOUT_DEFAULT = 3
38
43
  SERIAL_PORT_DEFAULT = get_default_port()
44
+ BROADCAST_ADDRESS = 0xFFFFFFFFFFFFFFFF
39
45
 
40
46
 
41
47
  @dataclass
@@ -44,6 +50,7 @@ class DataChunk:
44
50
 
45
51
  index: int
46
52
  size: int
53
+ sha: bytes
47
54
  data: bytes
48
55
 
49
56
 
@@ -53,16 +60,29 @@ class StartOtaData:
53
60
 
54
61
  chunks: int = 0
55
62
  fw_hash: bytes = b""
56
- ids: list[str] = dataclasses.field(default_factory=lambda: [])
63
+ addrs: list[str] = dataclasses.field(default_factory=lambda: [])
64
+ retries: int = 0
65
+
66
+
67
@dataclass
class Chunk:
    """Transfer bookkeeping for one firmware chunk on one device."""

    # Pre-formatted display strings (e.g. "003" / "064B" when built by
    # Controller.transfer).
    index: str = "0"
    size: str = "0B"
    # 0 until an OTA chunk ack is recorded for this chunk, then 1.
    acked: int = 0
    # Retry counter written by Controller.send_chunk on every (re)send.
    retries: int = 0

    def __repr__(self):
        # Render as the plain dict of field values.
        return str(dataclasses.asdict(self))
57
78
 
58
79
 
59
80
@dataclass
class TransferDataStatus:
    """Class that holds transfer data status for a single device."""

    # One Chunk entry per firmware chunk, in chunk order.
    chunks: list[Chunk] = dataclasses.field(default_factory=list)
    # Set by Controller.transfer once every chunk has been acked.
    success: bool = False
66
86
 
67
87
 
68
88
  @dataclass
@@ -76,57 +96,38 @@ class ResetLocation:
76
96
  return f"(x={self.pos_x}, y={self.pos_y})"
77
97
 
78
98
 
79
- def print_status(status_data: dict[str, StatusType]) -> None:
80
- """Print the status of the devices."""
81
- print()
82
- print(
83
- f"{len(status_data)} device{'s' if len(status_data) > 1 else ''} found"
84
- )
85
- print()
86
- status_table = Table()
87
- status_table.add_column("Device ID", style="magenta", no_wrap=True)
88
- status_table.add_column("Status", style="green", justify="center")
89
- with Live(status_table, refresh_per_second=4) as live:
90
- live.update(status_table)
91
- for device_id, status in sorted(status_data.items()):
92
- status_table.add_row(
93
- f"{device_id}",
94
- f'{"[bold cyan]" if status == StatusType.Running else "[bold green]"}{status.name}',
95
- )
99
def addr_to_hex(addr: int) -> str:
    """Return ``addr`` as a 16-character upper-case hexadecimal string.

    The address is serialised as an unsigned 8-byte big-endian value, so a
    negative or wider-than-64-bit value raises ``OverflowError``.
    """
    raw = addr.to_bytes(8, "big")
    return raw.hex().upper()
96
102
 
97
103
 
98
- def print_start_status(
99
- stopped_data: list[str], not_started: list[str]
100
- ) -> None:
101
- """Print the start status."""
102
- print("[bold]Start status:[/]")
103
- status_table = Table()
104
- status_table.add_column("Device ID", style="magenta", no_wrap=True)
105
- status_table.add_column("Status", style="green", justify="center")
106
- with Live(status_table, refresh_per_second=4) as live:
107
- live.update(status_table)
108
- for device_id in sorted(stopped_data):
109
- status_table.add_row(
110
- f"{device_id}", "[bold green]:heavy_check_mark:[/]"
111
- )
112
- for device_id in sorted(not_started):
113
- status_table.add_row(f"{device_id}", "[bold red]:x:[/]")
114
-
115
-
116
- def print_stop_status(stopped_data: list[str], not_stopped: list[str]) -> None:
117
- """Print the stop status."""
118
- print("[bold]Stop status:[/]")
119
- status_table = Table()
120
- status_table.add_column("Device ID", style="magenta", no_wrap=True)
121
- status_table.add_column("Status", style="green", justify="center")
122
- with Live(status_table, refresh_per_second=4) as live:
123
- live.update(status_table)
124
- for device_id in sorted(stopped_data):
125
- status_table.add_row(
126
- f"{device_id}", "[bold green]:heavy_check_mark:[/]"
127
- )
128
- for device_id in sorted(not_stopped):
129
- status_table.add_row(f"{device_id}", "[bold red]:x:[/]")
104
def generate_status(status_data, devices=None, status_message="found"):
    """Build a rich renderable summarising the status of the devices.

    Args:
        status_data: Mapping of device address to its ``StatusType``.
        devices: Optional collection of device addresses to restrict the
            output to. ``None`` or an empty collection shows every device.
        status_message: Text appended to the header line.

    Returns:
        A ``Group`` containing either a single "No device ..." line, or a
        header line followed by a table sorted by device address.
    """
    # ``None`` replaces the former mutable ``[]`` default; both mean
    # "no filter".
    data = {
        addr: status
        for addr, status in status_data.items()
        if not devices or addr in devices
    }
    if not data:
        return Group(Text(f"\nNo device {status_message}\n"))

    header = Text(
        f"\n{len(data)} device{'s' if len(data) > 1 else ''} {status_message}\n"
    )

    table = Table()
    table.add_column("Device Addr", style="magenta", no_wrap=True)
    table.add_column(
        "Status",
        style="green",
        justify="center",
        # Fixed width (longest status name) keeps the column from resizing
        # when the table is re-rendered with different statuses.
        width=max(len(name) for name in StatusType.__members__),
    )
    for device_addr, status in sorted(data.items()):
        colour = (
            "[bold cyan]" if status == StatusType.Running else "[bold green]"
        )
        table.add_row(f"{device_addr}", f"{colour}{status.name}")
    return Group(header, table)
130
131
 
131
132
 
132
133
  def print_transfer_status(
@@ -137,26 +138,19 @@ def print_transfer_status(
137
138
  print("[bold]Transfer status:[/]")
138
139
  transfer_status_table = Table()
139
140
  transfer_status_table.add_column(
140
- "Device ID", style="magenta", no_wrap=True
141
+ "Device Addr", style="magenta", no_wrap=True
141
142
  )
142
143
  transfer_status_table.add_column(
143
144
  "Chunks acked", style="green", justify="center"
144
145
  )
145
- transfer_status_table.add_column(
146
- "Hashes match", style="green", justify="center"
147
- )
146
+
148
147
  with Live(transfer_status_table, refresh_per_second=4) as live:
149
148
  live.update(transfer_status_table)
150
- for device_id, status in sorted(status.items()):
151
- start_marker, stop_marker = (
152
- ("[bold green]", "[/]")
153
- if bool(status.hashes_match) is True
154
- else ("[bold red]", "[/]")
155
- )
149
+ for device_addr, status in sorted(status.items()):
150
+ chunks_col_color = "[green]" if status.success else "[bold red]"
156
151
  transfer_status_table.add_row(
157
- f"{device_id}",
158
- f"{len(status.chunks_acked)}/{start_data.chunks}",
159
- f"{start_marker}{bool(status.hashes_match)}{stop_marker}",
152
+ f"{device_addr}",
153
+ f"{chunks_col_color}{len([chunk for chunk in status.chunks if bool(chunk.acked)])}/{start_data.chunks}",
160
154
  )
161
155
 
162
156
 
@@ -176,10 +170,15 @@ class ControllerSettings:
176
170
 
177
171
  serial_port: str = SERIAL_PORT_DEFAULT
178
172
  serial_baudrate: int = 1000000
179
- mqtt_host: str = "argus.paris.inria.fr"
180
- mqtt_port: int = 8883
181
- edge: bool = False
173
+ mqtt_host: str = "localhost"
174
+ mqtt_port: int = 1883
175
+ mqtt_use_tls: bool = False
176
+ network_id: int = 1
177
+ adapter: str = "serial" # or "mqtt", "marilib-edge", "marilib-cloud"
182
178
  devices: list[str] = dataclasses.field(default_factory=lambda: [])
179
+ ota_max_retries: int = OTA_MAX_RETRIES_DEFAULT
180
+ ota_timeout: float = OTA_ACK_TIMEOUT_DEFAULT
181
+ verbose: bool = False
183
182
 
184
183
 
185
184
  class Controller:
@@ -196,43 +195,44 @@ class Controller:
196
195
  self.start_ota_data: StartOtaData = StartOtaData()
197
196
  self.transfer_data: dict[str, TransferDataStatus] = {}
198
197
  self._known_devices: dict[str, StatusType] = {}
199
- self.expected_reply: Optional[SwarmitPayloadType] = None
200
198
  register_parsers()
201
- if self.settings.edge is True:
202
- self._interface = MQTTAdapter(
203
- self.settings.mqtt_host, self.settings.mqtt_port
199
+ if self.settings.adapter == "cloud":
200
+ self._interface = MarilibCloudAdapter(
201
+ self.settings.mqtt_host,
202
+ self.settings.mqtt_port,
203
+ self.settings.mqtt_use_tls,
204
+ self.settings.network_id,
205
+ verbose=self.settings.verbose,
204
206
  )
205
207
  else:
206
- try:
207
- self._interface = SerialAdapter(
208
- self.settings.serial_port, self.settings.serial_baudrate
209
- )
210
- except (
211
- SerialInterfaceException,
212
- serial.serialutil.SerialException,
213
- ) as exc:
214
- console = Console()
215
- console.print(f"[bold red]Error:[/] {exc}")
216
- self._interface.init(self.on_data_received)
208
+ self._interface = MarilibEdgeAdapter(
209
+ self.settings.serial_port,
210
+ self.settings.serial_baudrate,
211
+ verbose=self.settings.verbose,
212
+ )
213
+ self._interface.init(self.on_frame_received)
217
214
 
218
215
  @property
219
216
  def known_devices(self) -> dict[str, StatusType]:
220
217
  """Return the known devices."""
221
218
  if not self._known_devices:
222
- self._known_devices = self.status()
219
+ wait_for_done(COMMAND_TIMEOUT, lambda: False)
220
+ self._known_devices = self.status_data
223
221
  return self._known_devices
224
222
 
225
223
  @property
226
224
  def running_devices(self) -> list[str]:
227
225
  """Return the running devices."""
228
226
  return [
229
- device_id
230
- for device_id, status in self.known_devices.items()
227
+ addr
228
+ for addr, status in self.known_devices.items()
231
229
  if (
232
- status == StatusType.Running
230
+ (
231
+ status == StatusType.Running
232
+ or status == StatusType.Programming
233
+ )
233
234
  and (
234
- not self.settings.devices
235
- or device_id in self.settings.devices
235
+ not self.settings.devices or addr in self.settings.devices
236
236
  )
237
237
  )
238
238
  ]
@@ -241,13 +241,13 @@ class Controller:
241
241
  def resetting_devices(self) -> list[str]:
242
242
  """Return the resetting devices."""
243
243
  return [
244
- device_id
245
- for device_id, status in self.known_devices.items()
244
+ device_addr
245
+ for device_addr, status in self.known_devices.items()
246
246
  if (
247
247
  status == StatusType.Resetting
248
248
  and (
249
249
  not self.settings.devices
250
- or device_id in self.settings.devices
250
+ or device_addr in self.settings.devices
251
251
  )
252
252
  )
253
253
  ]
@@ -256,13 +256,13 @@ class Controller:
256
256
  def ready_devices(self) -> list[str]:
257
257
  """Return the ready devices."""
258
258
  return [
259
- device_id
260
- for device_id, status in self.known_devices.items()
259
+ device_addr
260
+ for device_addr, status in self.known_devices.items()
261
261
  if (
262
262
  status == StatusType.Bootloader
263
263
  and (
264
264
  not self.settings.devices
265
- or device_id in self.settings.devices
265
+ or device_addr in self.settings.devices
266
266
  )
267
267
  )
268
268
  ]
@@ -276,181 +276,176 @@ class Controller:
276
276
  """Terminate the controller."""
277
277
  self.interface.close()
278
278
 
279
- def send_frame(self, frame: Frame):
279
+ def send_payload(self, destination: int, payload: Payload):
280
280
  """Send a frame to the devices."""
281
- self.interface.send_data(frame.to_bytes())
282
-
283
- def on_data_received(self, data):
284
- frame = Frame().from_bytes(data)
285
- if frame.payload_type < SwarmitPayloadType.SWARMIT_REQUEST_STATUS:
281
+ self.interface.send_payload(destination, payload)
282
+
283
+ def on_frame_received(self, header, packet: Packet):
284
+ """Handle the received frame."""
285
+ # if self.settings.verbose:
286
+ # print()
287
+ # print(Frame(header, packet))
288
+ if packet.payload_type < SwarmitPayloadType.SWARMIT_REQUEST_STATUS:
286
289
  return
287
- device_id = f"{frame.payload.device_id:08X}"
290
+ device_addr = f"{header.source:08X}"
288
291
  if (
289
- frame.payload_type
290
- == SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS
291
- and self.expected_reply
292
+ packet.payload_type
292
293
  == SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS
293
294
  ):
294
295
  self.status_data.update(
295
- {device_id: StatusType(frame.payload.status)}
296
+ {device_addr: StatusType(packet.payload.status)}
296
297
  )
297
298
  elif (
298
- frame.payload_type
299
- == SwarmitPayloadType.SWARMIT_NOTIFICATION_STARTED
300
- and self.expected_reply
301
- == SwarmitPayloadType.SWARMIT_NOTIFICATION_STARTED
302
- ):
303
- if device_id not in self.started_data:
304
- self.started_data.append(device_id)
305
- elif (
306
- frame.payload_type
307
- == SwarmitPayloadType.SWARMIT_NOTIFICATION_STOPPED
308
- and self.expected_reply
309
- == SwarmitPayloadType.SWARMIT_NOTIFICATION_STOPPED
310
- ):
311
- if device_id not in self.stopped_data:
312
- self.stopped_data.append(device_id)
313
- elif (
314
- frame.payload_type
315
- == SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK
316
- and self.expected_reply
299
+ packet.payload_type
317
300
  == SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK
318
301
  ):
319
- if device_id not in self.start_ota_data.ids:
320
- self.start_ota_data.ids.append(device_id)
302
+ if device_addr in self.start_ota_data.addrs:
303
+ return
304
+ self.start_ota_data.addrs.append(device_addr)
321
305
  elif (
322
- frame.payload_type
306
+ packet.payload_type
323
307
  == SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_CHUNK_ACK
324
308
  ):
325
- if (
326
- frame.payload.index
327
- not in self.transfer_data[device_id].chunks_acked
328
- ):
329
- self.transfer_data[device_id].chunks_acked.add(
330
- frame.payload.index
309
+ try:
310
+ acked = bool(
311
+ self.transfer_data[device_addr]
312
+ .chunks[packet.payload.index]
313
+ .acked
331
314
  )
332
- self.transfer_data[device_id].hashes_match = (
333
- frame.payload.hashes_match
334
- )
335
- elif frame.payload_type in [
315
+ except (IndexError, KeyError):
316
+ self.logger.warning(
317
+ "Chunk index out of range",
318
+ device_addr=device_addr,
319
+ chunk_index=packet.payload.index,
320
+ )
321
+ return
322
+ if acked is False:
323
+ self.transfer_data[device_addr].chunks[
324
+ packet.payload.index
325
+ ].acked = 1
326
+ elif packet.payload_type in [
336
327
  SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_GPIO,
337
328
  SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_LOG,
338
329
  ]:
339
330
  if (
340
331
  self.settings.devices
341
- and device_id not in self.settings.devices
332
+ and device_addr not in self.settings.devices
342
333
  ):
343
334
  return
344
335
  logger = self.logger.bind(
345
- deviceid=device_id,
346
- notification=frame.payload_type.name,
347
- timestamp=frame.payload.timestamp,
348
- data_size=frame.payload.count,
349
- data=frame.payload.data,
336
+ device_addr=device_addr,
337
+ notification=SwarmitPayloadType(packet.payload_type).name,
338
+ timestamp=packet.payload.timestamp,
339
+ data_size=packet.payload.count,
340
+ data=packet.payload.data,
350
341
  )
351
342
  if (
352
- frame.payload_type
343
+ packet.payload_type
353
344
  == SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_GPIO
354
345
  ):
355
346
  logger.info("GPIO event")
356
347
  elif (
357
- frame.payload_type
348
+ packet.payload_type
358
349
  == SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_LOG
359
350
  ):
360
351
  logger.info("LOG event")
361
- elif frame.payload_type != self.expected_reply:
362
- self.logger.warning(
363
- "Unexpected payload",
364
- payload_type=hex(frame.payload_type),
365
- expected=hex(self.expected_reply),
366
- )
367
352
  else:
368
353
  self.logger.error(
369
- "Unknown payload type", payload_type=frame.payload_type
354
+ "Unknown payload type", payload_type=packet.payload_type
370
355
  )
371
356
 
357
+ def _live_status(
358
+ self, devices=[], timeout=STATUS_TIMEOUT, message="found"
359
+ ):
360
+ """Request the live status of the testbed."""
361
+ with Live(
362
+ generate_status(self.status_data, devices, status_message=message),
363
+ refresh_per_second=4,
364
+ ) as live:
365
+ while timeout > 0:
366
+ live.update(
367
+ generate_status(
368
+ self.status_data, devices, status_message=message
369
+ )
370
+ )
371
+ timeout -= 0.01
372
+ time.sleep(0.01)
373
+
372
374
  def status(self):
373
375
  """Request the status of the testbed."""
374
- self.status_data: dict[str, StatusType] = {}
375
- payload = PayloadStatusRequest(device_id=0)
376
- frame = Frame(header=Header(), payload=payload)
377
- self.expected_reply = SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS
378
- self.send_frame(frame)
379
- wait_for_done(1, lambda: False)
380
- return self.status_data
381
-
382
- def _send_start(self, device_id: str):
383
- def is_started():
384
- if device_id == "0":
385
- return sorted(self.started_data) == sorted(self.ready_devices)
386
- else:
387
- return device_id in self.started_data
376
+ self._live_status(self.settings.devices)
388
377
 
389
- self.expected_reply = SwarmitPayloadType.SWARMIT_NOTIFICATION_STARTED
390
- payload = PayloadStartRequest(device_id=int(device_id, base=16))
391
- self.send_frame(Frame(header=Header(), payload=payload))
392
- wait_for_done(3, is_started)
393
- self.expected_reply = None
378
+ def _send_start(self, device_addr: str):
379
+ payload = PayloadStartRequest()
380
+ self.send_payload(int(device_addr, 16), payload)
394
381
 
395
382
  def start(self):
396
383
  """Start the application."""
397
- self.started_data = []
398
384
  ready_devices = self.ready_devices
399
- if not self.settings.devices:
400
- self._send_start("0")
401
- else:
402
- for device_id in self.settings.devices:
403
- if device_id not in ready_devices:
404
- continue
405
- self._send_start(device_id)
406
- return self.started_data
407
-
408
- def _send_stop(self, device_id: str):
409
- stoppable_devices = self.running_devices + self.resetting_devices
410
-
411
- def is_stopped():
412
- if device_id == "0":
413
- return sorted(self.stopped_data) == sorted(stoppable_devices)
385
+ attempts = 0
386
+ while attempts < COMMAND_MAX_ATTEMPTS and not all(
387
+ self.status_data[addr] == StatusType.Running
388
+ for addr in ready_devices
389
+ ):
390
+ if not self.settings.devices:
391
+ self._send_start(addr_to_hex(BROADCAST_ADDRESS))
414
392
  else:
415
- return device_id in self.stopped_data
416
-
417
- self.expected_reply = SwarmitPayloadType.SWARMIT_NOTIFICATION_STOPPED
418
- payload = PayloadStopRequest(device_id=int(device_id, base=16))
419
- self.send_frame(Frame(header=Header(), payload=payload))
420
- wait_for_done(3, is_stopped)
421
- self.expected_reply = None
393
+ for device_addr in self.settings.devices:
394
+ if device_addr not in ready_devices:
395
+ continue
396
+ self._send_start(device_addr)
397
+ attempts += 1
398
+ time.sleep(COMMAND_ATTEMPT_DELAY)
399
+ self._live_status(
400
+ ready_devices, timeout=COMMAND_TIMEOUT, message="to start"
401
+ )
422
402
 
423
403
  def stop(self):
424
404
  """Stop the application."""
425
- self.stopped_data = []
426
405
  stoppable_devices = self.running_devices + self.resetting_devices
427
- if not self.settings.devices:
428
- self._send_stop("0")
429
- else:
430
- for device_id in self.settings.devices:
431
- if device_id not in stoppable_devices:
432
- continue
433
- self._send_stop(device_id)
434
- return self.stopped_data
435
406
 
436
- def _send_reset(self, device_id: str, location: ResetLocation):
407
+ attempts = 0
408
+ while attempts < COMMAND_MAX_ATTEMPTS and not all(
409
+ self.status_data[addr]
410
+ in [StatusType.Stopping, StatusType.Bootloader]
411
+ for addr in stoppable_devices
412
+ ):
413
+ if not self.settings.devices:
414
+ self.send_payload(BROADCAST_ADDRESS, PayloadStopRequest())
415
+ else:
416
+ for device_addr in self.settings.devices:
417
+ if (
418
+ device_addr not in stoppable_devices
419
+ or self.status_data[device_addr].status
420
+ in [StatusType.Stopping, StatusType.Bootloader]
421
+ ):
422
+ continue
423
+ self.send_payload(
424
+ int(device_addr, 16), PayloadStopRequest()
425
+ )
426
+ attempts += 1
427
+ time.sleep(COMMAND_ATTEMPT_DELAY)
428
+ self._live_status(
429
+ stoppable_devices, timeout=COMMAND_TIMEOUT, message="to stop"
430
+ )
431
+
432
+ def _send_reset(self, device_addr: int, location: ResetLocation):
437
433
  payload = PayloadResetRequest(
438
- device_id=int(device_id, base=16),
439
434
  pos_x=location.pos_x,
440
435
  pos_y=location.pos_y,
441
436
  )
442
- self.send_frame(Frame(header=Header(), payload=payload))
437
+ self.send_payload(device_addr, payload)
443
438
 
444
439
  def reset(self, locations: dict[str, ResetLocation]):
445
440
  """Reset the application."""
446
441
  ready_devices = self.ready_devices
447
- for device_id in self.settings.devices:
448
- if device_id not in ready_devices:
442
+ for device_addr in self.settings.devices:
443
+ if device_addr not in ready_devices:
449
444
  continue
450
445
  print(
451
- f"Resetting device {device_id} with location {locations[device_id]}"
446
+ f"Resetting device {device_addr} with location {locations[device_addr]}"
452
447
  )
453
- self._send_reset(device_id, locations[device_id])
448
+ self._send_reset(int(device_addr, 16), locations[device_addr])
454
449
 
455
450
  def monitor(self):
456
451
  """Monitor the testbed."""
@@ -458,44 +453,51 @@ class Controller:
458
453
  while True:
459
454
  time.sleep(0.01)
460
455
 
461
- def _send_message(self, device_id, message):
456
+ def _send_message(self, device_addr: int, message: str):
462
457
  payload = PayloadMessage(
463
- device_id=int(device_id, base=16),
464
458
  count=len(message),
465
459
  message=message.encode(),
466
460
  )
467
- frame = Frame(header=Header(), payload=payload)
468
- self.send_frame(frame)
461
+ self.send_payload(device_addr, payload)
469
462
 
470
463
  def send_message(self, message):
471
464
  """Send a message to the devices."""
472
465
  running_devices = self.running_devices
473
466
  if not self.settings.devices:
474
- self._send_message("0", message)
467
+ self._send_message(BROADCAST_ADDRESS, message)
475
468
  else:
476
- for device_id in self.settings.devices:
477
- if device_id not in running_devices:
469
+ for addr in self.settings.devices:
470
+ if addr not in running_devices:
478
471
  continue
479
- self._send_message(device_id, message)
480
-
481
- def _send_start_ota(self, device_id: str, firmware: bytes):
472
+ self._send_message(int(addr, 16), message)
482
473
 
474
+ def _send_start_ota(
475
+ self, device_addr: str, devices_to_flash: set[str], firmware: bytes
476
+ ):
483
477
  def is_start_ota_acknowledged():
484
- if device_id == "0":
485
- return sorted(self.start_ota_data.ids) == sorted(
486
- self.ready_devices
478
+ if int(device_addr, 16) == BROADCAST_ADDRESS:
479
+ return sorted(self.start_ota_data.addrs) == sorted(
480
+ devices_to_flash
487
481
  )
488
482
  else:
489
- return device_id in self.start_ota_data.ids
483
+ return device_addr in self.start_ota_data.addrs
490
484
 
491
485
  payload = PayloadOTAStartRequest(
492
- device_id=int(device_id, base=16),
493
486
  fw_length=len(firmware),
494
487
  fw_chunk_count=len(self.chunks),
495
- fw_hash=self.fw_hash,
496
488
  )
497
- self.send_frame(Frame(header=Header(), payload=payload))
498
- wait_for_done(3, is_start_ota_acknowledged)
489
+ send_time = time.time()
490
+ send = True
491
+ while (
492
+ not is_start_ota_acknowledged()
493
+ and self.start_ota_data.retries <= self.settings.ota_max_retries
494
+ ):
495
+ if send is True:
496
+ self.send_payload(int(device_addr, 16), payload)
497
+ send_time = time.time()
498
+ self.start_ota_data.retries += 1
499
+ time.sleep(0.001)
500
+ send = time.time() - send_time > self.settings.ota_timeout
499
501
 
500
502
  def start_ota(self, firmware) -> StartOtaData:
501
503
  """Start the OTA process."""
@@ -514,107 +516,149 @@ class Controller:
514
516
  chunk_idx * CHUNK_SIZE : chunk_idx * CHUNK_SIZE + chunk_size
515
517
  ]
516
518
  digest.update(data)
519
+ chunk_sha = hashes.Hash(hashes.SHA256())
520
+ chunk_sha.update(data)
517
521
  self.chunks.append(
518
522
  DataChunk(
519
523
  index=chunk_idx,
520
524
  size=chunk_size,
525
+ sha=chunk_sha.finalize()[
526
+ :8
527
+ ], # the first 8 bytes should be enough
521
528
  data=data,
522
529
  )
523
530
  )
524
- self.fw_hash = digest.finalize()
525
- self.expected_reply = (
526
- SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK
527
- )
528
- self.start_ota_data.fw_hash = self.fw_hash
531
+ self.start_ota_data.fw_hash = digest.finalize()
529
532
  self.start_ota_data.chunks = len(self.chunks)
533
+ devices_to_flash = self.ready_devices
530
534
  if not self.settings.devices:
531
535
  print("Broadcast start ota notification...")
532
- self._send_start_ota("0", firmware)
536
+ self._send_start_ota(
537
+ addr_to_hex(BROADCAST_ADDRESS), devices_to_flash, firmware
538
+ )
533
539
  else:
534
- for device_id in self.settings.devices:
535
- print(f"Sending start ota notification to {device_id}...")
536
- self._send_start_ota(device_id, firmware)
537
- self.expected_reply = None
538
- return self.start_ota_data
539
-
540
- def send_chunk(self, chunk, device_id: str):
541
-
540
+ for addr in devices_to_flash:
541
+ print(f"Sending start ota notification to {addr}...")
542
+ self._send_start_ota(addr, devices_to_flash, firmware)
543
+ time.sleep(0.2)
544
+ return {
545
+ "ota": self.start_ota_data,
546
+ "acked": sorted(self.start_ota_data.addrs),
547
+ "missed": sorted(
548
+ set(devices_to_flash).difference(
549
+ set(self.start_ota_data.addrs)
550
+ )
551
+ ),
552
+ }
553
+
554
+ def send_chunk(
555
+ self,
556
+ chunk: DataChunk,
557
+ device_addr: str,
558
+ devices_to_flash: set[str],
559
+ ):
542
560
  def is_chunk_acknowledged():
543
- if device_id == "0":
561
+ if int(device_addr, 16) == BROADCAST_ADDRESS:
544
562
  return sorted(self.transfer_data.keys()) == sorted(
545
- self.ready_devices
563
+ devices_to_flash
546
564
  ) and all(
547
565
  [
548
- chunk.index in status.chunks_acked
566
+ status.chunks[chunk.index].acked
549
567
  for status in self.transfer_data.values()
550
568
  ]
551
569
  )
552
570
  else:
553
571
  return (
554
- device_id in self.transfer_data.keys()
555
- and chunk.index
556
- in self.transfer_data[device_id].chunks_acked
572
+ device_addr in self.transfer_data.keys()
573
+ and self.transfer_data[device_addr]
574
+ .chunks[chunk.index]
575
+ .acked
557
576
  )
558
577
 
578
+ payload = PayloadOTAChunkRequest(
579
+ index=chunk.index,
580
+ count=chunk.size,
581
+ sha=chunk.sha,
582
+ chunk=chunk.data,
583
+ )
559
584
  send_time = time.time()
560
585
  send = True
561
- tries = 0
562
- while tries < 3:
563
- if is_chunk_acknowledged():
564
- break
586
+ retries_count = 0
587
+ while (
588
+ not is_chunk_acknowledged()
589
+ and retries_count <= self.settings.ota_max_retries
590
+ ):
565
591
  if send is True:
566
- payload = PayloadOTAChunkRequest(
567
- device_id=int(device_id, base=16),
568
- index=chunk.index,
569
- count=chunk.size,
570
- chunk=chunk.data,
571
- )
572
- self.send_frame(Frame(header=Header(), payload=payload))
573
- if device_id == "0":
574
- for device_id in self.ready_devices:
575
- self.transfer_data[device_id].retries[
592
+ if self.settings.verbose:
593
+ missing_acks = [
594
+ addr
595
+ for addr in devices_to_flash
596
+ if addr not in self.transfer_data
597
+ or not self.transfer_data[addr]
598
+ .chunks[chunk.index]
599
+ .acked
600
+ ]
601
+ print(
602
+ f"Transferring chunk {chunk.index}/{len(self.start_ota_data.chunks)} to {device_addr} "
603
+ f"- {retries_count} retries "
604
+ f"- {len(missing_acks)} missing acks: {', '.join(missing_acks) if missing_acks else 'none'}"
605
+ )
606
+ self.send_payload(int(device_addr, 16), payload)
607
+ if int(device_addr, 16) == BROADCAST_ADDRESS:
608
+ for addr in devices_to_flash:
609
+ self.transfer_data[addr].chunks[
576
610
  chunk.index
577
- ] = tries
611
+ ].retries = retries_count
578
612
  else:
579
- self.transfer_data[device_id].retries[chunk.index] = tries
580
- tries += 1
581
- time.sleep(0.01)
613
+ self.transfer_data[device_addr].chunks[
614
+ chunk.index
615
+ ].retries = retries_count
582
616
  send_time = time.time()
617
+ retries_count += 1
583
618
  time.sleep(0.001)
584
- send = time.time() - send_time > 1
619
+ send = time.time() - send_time > self.settings.ota_timeout
585
620
 
586
- def transfer(self, firmware):
621
+ def transfer(self, firmware, devices) -> dict[str, TransferDataStatus]:
587
622
  """Transfer the firmware to the devices."""
588
623
  data_size = len(firmware)
589
- progress = tqdm(
590
- range(0, data_size),
591
- unit="B",
592
- unit_scale=False,
593
- colour="green",
594
- ncols=100,
595
- )
596
- progress.set_description(
597
- f"Loading firmware ({int(data_size / 1024)}kB)"
598
- )
599
- self.expected_reply = (
600
- SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_CHUNK_ACK
601
- )
624
+ use_progress_bar = not self.settings.verbose
625
+ if use_progress_bar:
626
+ progress = tqdm(
627
+ range(0, data_size),
628
+ unit="B",
629
+ unit_scale=False,
630
+ colour="green",
631
+ ncols=100,
632
+ )
633
+ progress.set_description(
634
+ f"Loading firmware ({int(data_size / 1024)}kB)"
635
+ )
602
636
  self.transfer_data = {}
603
- if not self.settings.devices:
604
- for device_id in self.ready_devices:
605
- self.transfer_data[device_id] = TransferDataStatus()
606
- self.transfer_data[device_id].retries = [0] * len(self.chunks)
607
- else:
608
- for device_id in self.settings.devices:
609
- self.transfer_data[device_id] = TransferDataStatus()
610
- self.transfer_data[device_id].retries = [0] * len(self.chunks)
637
+ for device_addr in devices:
638
+ self.transfer_data[device_addr] = TransferDataStatus()
639
+ self.transfer_data[device_addr].chunks = [
640
+ Chunk(index=f"{i:03d}", size=f"{self.chunks[i].size:03d}B")
641
+ for i in range(len(self.chunks))
642
+ ]
611
643
  for chunk in self.chunks:
612
644
  if not self.settings.devices:
613
- self.send_chunk(chunk, "0")
645
+ self.send_chunk(
646
+ chunk,
647
+ addr_to_hex(BROADCAST_ADDRESS),
648
+ devices,
649
+ )
614
650
  else:
615
- for device_id in self.settings.devices:
616
- self.send_chunk(chunk, device_id)
617
- progress.update(chunk.size)
618
- progress.close()
619
- self.expected_reply = None
651
+ for addr in devices:
652
+ self.send_chunk(chunk, addr, devices)
653
+ if use_progress_bar:
654
+ progress.update(chunk.size)
655
+ if use_progress_bar:
656
+ progress.close()
657
+ for device in devices:
658
+ device_data = self.transfer_data.get(device)
659
+ if device_data:
660
+ device_data.success = all(
661
+ chunk.acked for chunk in device_data.chunks
662
+ )
663
+ self.transfer_data[device] = device_data
620
664
  return self.transfer_data