swarmit 0.2.0__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,664 @@
1
+ """Module containing the swarmit controller class."""
2
+
3
+ import dataclasses
4
+ import time
5
+ from binascii import hexlify
6
+ from dataclasses import dataclass
7
+
8
+ from cryptography.hazmat.primitives import hashes
9
+ from dotbot.logger import LOGGER
10
+ from dotbot.protocol import Packet, Payload
11
+ from dotbot.serial_interface import get_default_port
12
+ from rich import print
13
+ from rich.console import Group
14
+ from rich.live import Live
15
+ from rich.table import Table
16
+ from rich.text import Text
17
+ from tqdm import tqdm
18
+
19
+ from testbed.swarmit.adapter import (
20
+ GatewayAdapterBase,
21
+ MarilibCloudAdapter,
22
+ MarilibEdgeAdapter,
23
+ )
24
+ from testbed.swarmit.protocol import (
25
+ PayloadMessage,
26
+ PayloadOTAChunkRequest,
27
+ PayloadOTAStartRequest,
28
+ PayloadResetRequest,
29
+ PayloadStartRequest,
30
+ PayloadStopRequest,
31
+ StatusType,
32
+ SwarmitPayloadType,
33
+ register_parsers,
34
+ )
35
+
36
# Firmware images are split into chunks of this many bytes for OTA transfer.
CHUNK_SIZE = 64
# Seconds spent waiting for status notifications / showing command progress.
COMMAND_TIMEOUT = 6
# Maximum number of times a start/stop command is (re)sent.
COMMAND_MAX_ATTEMPTS = 5
# Seconds slept between two command attempts.
COMMAND_ATTEMPT_DELAY = 1
# Seconds the live status table stays on screen.
STATUS_TIMEOUT = 5
# Default OTA retry budget and per-transmission acknowledgement timeout (s).
OTA_MAX_RETRIES_DEFAULT = 10
OTA_ACK_TIMEOUT_DEFAULT = 3
# Serial port used when none is configured explicitly.
SERIAL_PORT_DEFAULT = get_default_port()
# All 64 bits set: destination used when no specific device is selected.
BROADCAST_ADDRESS = (1 << 64) - 1
45
+
46
+
47
@dataclass
class DataChunk:
    """One slice of a firmware image, ready to be sent over the air."""

    index: int  # position of the chunk within the firmware image
    size: int  # number of payload bytes in ``data``
    sha: bytes  # (truncated) digest of ``data``
    data: bytes  # raw payload bytes
55
+
56
+
57
@dataclass
class StartOtaData:
    """Bookkeeping for the OTA start phase."""

    chunks: int = 0  # number of chunks the firmware image was split into
    fw_hash: bytes = b""  # digest of the whole firmware image
    addrs: list[str] = dataclasses.field(default_factory=list)  # devices that acked the start request
    retries: int = 0  # start requests transmitted so far
65
+
66
+
67
@dataclass
class Chunk:
    """Per-device status of one OTA chunk."""

    index: str = "0"  # zero-padded chunk index, kept as text for display
    size: str = "0B"  # chunk size with a "B" suffix, kept as text for display
    acked: int = 0  # non-zero once the device acknowledged the chunk
    retries: int = 0  # retransmission counter recorded for this chunk

    def __repr__(self):
        as_dict = dataclasses.asdict(self)
        return str(as_dict)
78
+
79
+
80
@dataclass
class TransferDataStatus:
    """Transfer progress of a single device: per-chunk state and overall result."""

    chunks: list[Chunk] = dataclasses.field(default_factory=list)  # one entry per firmware chunk
    success: bool = False  # True once every chunk has been acknowledged
86
+
87
+
88
@dataclass
class ResetLocation:
    """Position sent to a device together with a reset request."""

    pos_x: int = 0
    pos_y: int = 0

    def __repr__(self):
        return "(x={}, y={})".format(self.pos_x, self.pos_y)
97
+
98
+
99
def addr_to_hex(addr: int) -> str:
    """Return *addr* as 16 upper-case hex digits (8 bytes, big-endian).

    Raises ``OverflowError`` if *addr* is negative or does not fit in 8 bytes.
    """
    raw = addr.to_bytes(8, "big")
    return raw.hex().upper()
102
+
103
+
104
def generate_status(status_data, devices=None, status_message="found"):
    """Build a rich renderable summarizing the status of devices.

    Args:
        status_data: mapping of device address to its ``StatusType``.
        devices: optional collection of addresses to restrict the table to;
            when empty or ``None`` every device in ``status_data`` is shown.
        status_message: words appended to the header line, e.g. "found".

    Returns:
        A ``Group`` with a header line and, when at least one device
        matches, a table of addresses and status names.
    """
    data = {
        key: device
        for key, device in status_data.items()
        if not devices or key in devices
    }
    if not data:
        return Group(Text(f"\nNo device {status_message}\n"))

    count = len(data)
    header = Text(
        f"\n{count} device{'s' if count > 1 else ''} {status_message}\n"
    )

    table = Table()
    table.add_column("Device Addr", style="magenta", no_wrap=True)
    table.add_column(
        "Status",
        style="green",
        justify="center",
        # Wide enough for the longest status name.
        width=max(len(name) for name in StatusType.__members__),
    )
    for device_addr, status in sorted(data.items()):
        # Running devices are highlighted differently from other states.
        color = "[bold cyan]" if status == StatusType.Running else "[bold green]"
        table.add_row(f"{device_addr}", f"{color}{status.name}")
    return Group(header, table)
131
+
132
+
133
def print_transfer_status(
    status: dict[str, TransferDataStatus], start_data: StartOtaData
) -> None:
    """Print, per device, how many firmware chunks were acknowledged.

    Args:
        status: per-device transfer status (as built by ``Controller.transfer``).
        start_data: OTA start data; its ``chunks`` field is the expected total.
    """
    print()
    print("[bold]Transfer status:[/]")
    transfer_status_table = Table()
    transfer_status_table.add_column(
        "Device Addr", style="magenta", no_wrap=True
    )
    transfer_status_table.add_column(
        "Chunks acked", style="green", justify="center"
    )

    with Live(transfer_status_table, refresh_per_second=4) as live:
        live.update(transfer_status_table)
        for device_addr, device_status in sorted(status.items()):
            # Red when at least one chunk was never acknowledged.
            color = "[green]" if device_status.success else "[bold red]"
            acked = sum(1 for chunk in device_status.chunks if chunk.acked)
            transfer_status_table.add_row(
                f"{device_addr}",
                f"{color}{acked}/{start_data.chunks}",
            )
155
+
156
+
157
def wait_for_done(timeout, condition_func):
    """Poll ``condition_func`` every 10 ms until it is truthy or time runs out.

    Args:
        timeout: maximum time to wait, in seconds. A non-positive value
            returns ``False`` without calling ``condition_func``.
        condition_func: zero-argument callable polled until it is truthy.

    Returns:
        ``True`` if the condition was met before the deadline, else ``False``.
    """
    # A monotonic deadline keeps the total wait close to ``timeout``.
    # Subtracting the nominal sleep each iteration drifts, because time spent
    # in ``condition_func`` and sleep overshoot are never accounted for.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if condition_func():
            return True
        time.sleep(0.01)
    return False
165
+
166
+
167
@dataclass
class ControllerSettings:
    """Configuration consumed by ``Controller``."""

    # Serial (edge) gateway connection.
    serial_port: str = SERIAL_PORT_DEFAULT
    serial_baudrate: int = 1_000_000
    # MQTT (cloud) gateway connection.
    mqtt_host: str = "localhost"
    mqtt_port: int = 1883
    mqtt_use_tls: bool = False
    network_id: int = 1
    # Gateway adapter selector; see Controller.__init__ for recognized values.
    adapter: str = "serial"
    # Device addresses to target; an empty list targets every device.
    devices: list[str] = dataclasses.field(default_factory=list)
    ota_max_retries: int = OTA_MAX_RETRIES_DEFAULT
    ota_timeout: float = OTA_ACK_TIMEOUT_DEFAULT  # seconds to wait for an OTA ack
    verbose: bool = False
182
+
183
+
184
class Controller:
    """Class used to control a swarm testbed.

    A controller connects to the testbed through a gateway adapter (the MQTT
    cloud adapter or the serial edge adapter). It records the notifications
    sent back by the devices and drives start/stop/reset, messaging and OTA
    firmware transfers.
    """

    # ``settings.adapter`` values that select the MQTT cloud adapter; any
    # other value selects the serial edge adapter. "mqtt" and
    # "marilib-cloud" are accepted as aliases of "cloud".
    _CLOUD_ADAPTERS = frozenset({"cloud", "mqtt", "marilib-cloud"})

    def __init__(self, settings: ControllerSettings):
        """Create the gateway adapter described by ``settings`` and connect it.

        Args:
            settings: controller configuration.
        """
        self.logger = LOGGER.bind(context=__name__)
        self.settings = settings
        self._interface: GatewayAdapterBase = None
        # Last status reported by each device, keyed by hex address.
        self.status_data: dict[str, StatusType] = {}
        self.started_data: list[str] = []
        self.stopped_data: list[str] = []
        # Firmware chunks prepared by start_ota().
        self.chunks: list[DataChunk] = []
        self.start_ota_data: StartOtaData = StartOtaData()
        # Per-device chunk acknowledgement state, rebuilt by transfer().
        self.transfer_data: dict[str, TransferDataStatus] = {}
        self._known_devices: dict[str, StatusType] = {}
        register_parsers()
        if self.settings.adapter in self._CLOUD_ADAPTERS:
            self._interface = MarilibCloudAdapter(
                self.settings.mqtt_host,
                self.settings.mqtt_port,
                self.settings.mqtt_use_tls,
                self.settings.network_id,
                verbose=self.settings.verbose,
            )
        else:
            self._interface = MarilibEdgeAdapter(
                self.settings.serial_port,
                self.settings.serial_baudrate,
                verbose=self.settings.verbose,
            )
        # Every frame received by the adapter is dispatched to on_frame_received.
        self._interface.init(self.on_frame_received)

    def _is_selected(self, device_addr: str) -> bool:
        """Return whether ``device_addr`` is targeted (no filter means all devices)."""
        return not self.settings.devices or device_addr in self.settings.devices

    @property
    def known_devices(self) -> dict[str, StatusType]:
        """Return the devices seen so far, keyed by hex address.

        While no device is known yet, each access blocks for
        ``COMMAND_TIMEOUT`` seconds so that status notifications can arrive.
        The returned dict is ``status_data`` itself, so it keeps updating.
        """
        if not self._known_devices:
            # The condition never holds: this is a plain timed wait.
            wait_for_done(COMMAND_TIMEOUT, lambda: False)
            self._known_devices = self.status_data
        return self._known_devices

    @property
    def running_devices(self) -> list[str]:
        """Return the selected devices in Running or Programming state."""
        return [
            addr
            for addr, status in self.known_devices.items()
            if status in (StatusType.Running, StatusType.Programming)
            and self._is_selected(addr)
        ]

    @property
    def resetting_devices(self) -> list[str]:
        """Return the selected devices in Resetting state."""
        return [
            addr
            for addr, status in self.known_devices.items()
            if status == StatusType.Resetting and self._is_selected(addr)
        ]

    @property
    def ready_devices(self) -> list[str]:
        """Return the selected devices in Bootloader state."""
        return [
            addr
            for addr, status in self.known_devices.items()
            if status == StatusType.Bootloader and self._is_selected(addr)
        ]

    @property
    def interface(self) -> GatewayAdapterBase:
        """Return the interface."""
        return self._interface

    def terminate(self):
        """Terminate the controller by closing the gateway interface."""
        self.interface.close()

    def send_payload(self, destination: int, payload: Payload):
        """Send ``payload`` to the device at ``destination`` (integer address)."""
        self.interface.send_payload(destination, payload)

    def on_frame_received(self, header, packet: Packet):
        """Handle a frame received from the gateway.

        Status notifications update ``status_data``. OTA start/chunk
        acknowledgements update ``start_ota_data`` / ``transfer_data``. GPIO
        and LOG events from selected devices are logged. Payload types below
        ``SWARMIT_REQUEST_STATUS`` are ignored.
        """
        payload_type = packet.payload_type
        if payload_type < SwarmitPayloadType.SWARMIT_REQUEST_STATUS:
            return
        # NOTE(review): "08X" pads to 8 hex digits while addr_to_hex pads to
        # 16; confirm both always render device addresses identically.
        device_addr = f"{header.source:08X}"
        if payload_type == SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS:
            self.status_data[device_addr] = StatusType(packet.payload.status)
        elif (
            payload_type
            == SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK
        ):
            if device_addr not in self.start_ota_data.addrs:
                self.start_ota_data.addrs.append(device_addr)
        elif (
            payload_type
            == SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_CHUNK_ACK
        ):
            index = packet.payload.index
            try:
                chunk = self.transfer_data[device_addr].chunks[index]
            except (IndexError, KeyError):
                self.logger.warning(
                    "Chunk index out of range",
                    device_addr=device_addr,
                    chunk_index=index,
                )
                return
            if not chunk.acked:
                chunk.acked = 1
        elif payload_type in (
            SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_GPIO,
            SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_LOG,
        ):
            if not self._is_selected(device_addr):
                return
            logger = self.logger.bind(
                device_addr=device_addr,
                notification=SwarmitPayloadType(payload_type).name,
                timestamp=packet.payload.timestamp,
                data_size=packet.payload.count,
                data=packet.payload.data,
            )
            if (
                payload_type
                == SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_GPIO
            ):
                logger.info("GPIO event")
            else:
                logger.info("LOG event")
        else:
            self.logger.error(
                "Unknown payload type", payload_type=packet.payload_type
            )

    def _live_status(
        self, devices=None, timeout=STATUS_TIMEOUT, message="found"
    ):
        """Show a live-updating status table for ``timeout`` seconds.

        Args:
            devices: addresses to show; ``None``/empty shows every device.
            timeout: display duration in seconds.
            message: words appended to the table header.
        """

        def render():
            return generate_status(
                self.status_data, devices, status_message=message
            )

        deadline = time.monotonic() + timeout
        with Live(render(), refresh_per_second=4) as live:
            while time.monotonic() < deadline:
                live.update(render())
                time.sleep(0.01)

    def status(self):
        """Request the status of the testbed."""
        self._live_status(self.settings.devices)

    def _send_start(self, device_addr: str):
        """Send a start request to the device with hex address ``device_addr``."""
        payload = PayloadStartRequest()
        self.send_payload(int(device_addr, 16), payload)

    def start(self):
        """Start the application on the ready (bootloader) devices.

        The start request is re-sent up to ``COMMAND_MAX_ATTEMPTS`` times
        until every targeted device reports Running. A live status table is
        then shown for ``COMMAND_TIMEOUT`` seconds.
        """
        ready_devices = self.ready_devices

        def all_running():
            return all(
                self.status_data[addr] == StatusType.Running
                for addr in ready_devices
            )

        attempts = 0
        while attempts < COMMAND_MAX_ATTEMPTS and not all_running():
            if not self.settings.devices:
                self._send_start(addr_to_hex(BROADCAST_ADDRESS))
            else:
                for device_addr in self.settings.devices:
                    if device_addr in ready_devices:
                        self._send_start(device_addr)
            attempts += 1
            time.sleep(COMMAND_ATTEMPT_DELAY)
        self._live_status(
            ready_devices, timeout=COMMAND_TIMEOUT, message="to start"
        )

    def stop(self):
        """Stop the application on running and resetting devices.

        The stop request is re-sent up to ``COMMAND_MAX_ATTEMPTS`` times
        until every targeted device reports Stopping or Bootloader. A live
        status table is then shown for ``COMMAND_TIMEOUT`` seconds.
        """
        stoppable_devices = self.running_devices + self.resetting_devices
        stopped_states = (StatusType.Stopping, StatusType.Bootloader)

        def is_stopped(addr):
            # status_data values are StatusType members, compared directly.
            return self.status_data[addr] in stopped_states

        attempts = 0
        while attempts < COMMAND_MAX_ATTEMPTS and not all(
            is_stopped(addr) for addr in stoppable_devices
        ):
            if not self.settings.devices:
                self.send_payload(BROADCAST_ADDRESS, PayloadStopRequest())
            else:
                for device_addr in self.settings.devices:
                    if device_addr not in stoppable_devices or is_stopped(
                        device_addr
                    ):
                        continue
                    self.send_payload(
                        int(device_addr, 16), PayloadStopRequest()
                    )
            attempts += 1
            time.sleep(COMMAND_ATTEMPT_DELAY)
        self._live_status(
            stoppable_devices, timeout=COMMAND_TIMEOUT, message="to stop"
        )

    def _send_reset(self, device_addr: int, location: ResetLocation):
        """Send a reset request carrying ``location`` to ``device_addr``."""
        payload = PayloadResetRequest(
            pos_x=location.pos_x,
            pos_y=location.pos_y,
        )
        self.send_payload(device_addr, payload)

    def reset(self, locations: dict[str, ResetLocation]):
        """Reset the selected ready devices, sending each its location.

        Only addresses listed in ``settings.devices`` that are currently in
        Bootloader state are reset. Each of them must have an entry in
        ``locations``; a missing entry raises ``KeyError``.
        """
        ready_devices = self.ready_devices
        for device_addr in self.settings.devices:
            if device_addr not in ready_devices:
                continue
            print(
                f"Resetting device {device_addr} with location {locations[device_addr]}"
            )
            self._send_reset(int(device_addr, 16), locations[device_addr])

    def monitor(self):
        """Monitor the testbed; events are logged by on_frame_received."""
        self.logger.info("Monitoring testbed")
        while True:
            time.sleep(0.01)

    def _send_message(self, device_addr: int, message: str):
        """Send ``message`` (UTF-8 encoded) to ``device_addr``."""
        encoded = message.encode()
        # ``count`` must be the byte length of the encoded payload; it
        # differs from len(message) for non-ASCII text.
        payload = PayloadMessage(
            count=len(encoded),
            message=encoded,
        )
        self.send_payload(device_addr, payload)

    def send_message(self, message):
        """Send a message to every device, or to the selected running ones."""
        running_devices = self.running_devices
        if not self.settings.devices:
            self._send_message(BROADCAST_ADDRESS, message)
        else:
            for addr in self.settings.devices:
                if addr not in running_devices:
                    continue
                self._send_message(int(addr, 16), message)

    def _send_until_acked(
        self, destination: int, payload: Payload, is_acked, on_send=None
    ) -> bool:
        """Send ``payload`` until ``is_acked()`` holds or the retry budget is spent.

        The payload is transmitted at most ``ota_max_retries + 1`` times, and
        every transmission, including the last one, gets ``ota_timeout``
        seconds to be acknowledged before the next one.

        Args:
            destination: integer address to send to.
            payload: payload to (re)transmit.
            is_acked: zero-argument callable, truthy once acknowledged.
            on_send: optional callable invoked with the zero-based attempt
                number right before each transmission.

        Returns:
            Whether the acknowledgement condition was met.
        """
        attempt = 0
        last_send = None
        while not is_acked():
            now = time.monotonic()
            if last_send is None or now - last_send > self.settings.ota_timeout:
                if attempt > self.settings.ota_max_retries:
                    return False
                if on_send is not None:
                    on_send(attempt)
                self.send_payload(destination, payload)
                last_send = time.monotonic()
                attempt += 1
            time.sleep(0.001)
        return True

    def _send_start_ota(
        self, device_addr: str, devices_to_flash: set[str], firmware: bytes
    ):
        """Send the OTA start request to ``device_addr`` and wait for acks.

        When ``device_addr`` is the broadcast address, every device in
        ``devices_to_flash`` must acknowledge. The retry budget applies per
        call, so each device gets the full number of attempts.
        ``start_ota_data.retries`` accumulates the total number of start
        requests sent.
        """
        destination = int(device_addr, 16)

        def is_start_ota_acknowledged():
            if destination == BROADCAST_ADDRESS:
                return set(devices_to_flash).issubset(self.start_ota_data.addrs)
            return device_addr in self.start_ota_data.addrs

        def count_retry(_attempt):
            self.start_ota_data.retries += 1

        payload = PayloadOTAStartRequest(
            fw_length=len(firmware),
            fw_chunk_count=len(self.chunks),
        )
        self._send_until_acked(
            destination, payload, is_start_ota_acknowledged, count_retry
        )

    def _build_chunks(self, firmware) -> bytes:
        """Split ``firmware`` into ``self.chunks``; return the image's SHA-256."""
        self.chunks = []
        digest = hashes.Hash(hashes.SHA256())
        # Slicing by offset gives every chunk its real length, including a
        # full-size last chunk when len(firmware) is a multiple of CHUNK_SIZE.
        for chunk_idx, offset in enumerate(range(0, len(firmware), CHUNK_SIZE)):
            data = firmware[offset : offset + CHUNK_SIZE]
            digest.update(data)
            chunk_sha = hashes.Hash(hashes.SHA256())
            chunk_sha.update(data)
            self.chunks.append(
                DataChunk(
                    index=chunk_idx,
                    size=len(data),
                    # the first 8 bytes should be enough
                    sha=chunk_sha.finalize()[:8],
                    data=data,
                )
            )
        return digest.finalize()

    def start_ota(self, firmware) -> dict:
        """Start the OTA process.

        Splits ``firmware`` into chunks, then asks the ready devices to
        enter OTA mode: by broadcast when no device is selected, otherwise
        one device at a time.

        Returns:
            A dict with ``"ota"`` (the ``StartOtaData``), ``"acked"`` (sorted
            addresses that acknowledged) and ``"missed"`` (sorted ready
            addresses that did not).
        """
        self.start_ota_data = StartOtaData()
        self.start_ota_data.fw_hash = self._build_chunks(firmware)
        self.start_ota_data.chunks = len(self.chunks)
        devices_to_flash = self.ready_devices
        if not self.settings.devices:
            print("Broadcast start ota notification...")
            self._send_start_ota(
                addr_to_hex(BROADCAST_ADDRESS), devices_to_flash, firmware
            )
        else:
            for addr in devices_to_flash:
                print(f"Sending start ota notification to {addr}...")
                self._send_start_ota(addr, devices_to_flash, firmware)
                time.sleep(0.2)
        return {
            "ota": self.start_ota_data,
            "acked": sorted(self.start_ota_data.addrs),
            "missed": sorted(
                set(devices_to_flash).difference(self.start_ota_data.addrs)
            ),
        }

    def send_chunk(
        self,
        chunk: DataChunk,
        device_addr: str,
        devices_to_flash: set[str],
    ):
        """Send one firmware chunk and wait until it is acknowledged.

        Args:
            chunk: the chunk to send.
            device_addr: hex address of the destination; the broadcast
                address sends to all ``devices_to_flash`` at once.
            devices_to_flash: addresses expected to acknowledge the chunk.
        """
        destination = int(device_addr, 16)
        is_broadcast = destination == BROADCAST_ADDRESS
        targets = list(devices_to_flash) if is_broadcast else [device_addr]

        def is_acked_by(addr):
            return addr in self.transfer_data and bool(
                self.transfer_data[addr].chunks[chunk.index].acked
            )

        def is_chunk_acknowledged():
            return all(is_acked_by(addr) for addr in targets)

        def on_send(attempt):
            if self.settings.verbose:
                missing_acks = [
                    addr for addr in devices_to_flash if not is_acked_by(addr)
                ]
                print(
                    f"Transferring chunk {chunk.index}/{self.start_ota_data.chunks} to {device_addr} "
                    f"- {attempt} retries "
                    f"- {len(missing_acks)} missing acks: {', '.join(missing_acks) if missing_acks else 'none'}"
                )
            for addr in targets:
                self.transfer_data[addr].chunks[chunk.index].retries = attempt

        payload = PayloadOTAChunkRequest(
            index=chunk.index,
            count=chunk.size,
            sha=chunk.sha,
            chunk=chunk.data,
        )
        self._send_until_acked(
            destination, payload, is_chunk_acknowledged, on_send
        )

    def transfer(self, firmware, devices) -> dict[str, TransferDataStatus]:
        """Transfer the firmware chunks prepared by start_ota() to ``devices``.

        Chunks are broadcast when no device is selected in the settings,
        otherwise sent to each device in turn. A progress bar is shown unless
        verbose mode is on.

        Returns:
            Per-device transfer status; ``success`` is True when every chunk
            was acknowledged.
        """
        data_size = len(firmware)
        progress = None
        if not self.settings.verbose:
            progress = tqdm(
                range(0, data_size),
                unit="B",
                unit_scale=False,
                colour="green",
                ncols=100,
            )
            progress.set_description(
                f"Loading firmware ({int(data_size / 1024)}kB)"
            )
        # Must be populated before sending: chunk acks are recorded here.
        self.transfer_data = {
            device_addr: TransferDataStatus(
                chunks=[
                    Chunk(index=f"{i:03d}", size=f"{data_chunk.size:03d}B")
                    for i, data_chunk in enumerate(self.chunks)
                ]
            )
            for device_addr in devices
        }
        try:
            for chunk in self.chunks:
                if not self.settings.devices:
                    self.send_chunk(
                        chunk,
                        addr_to_hex(BROADCAST_ADDRESS),
                        devices,
                    )
                else:
                    for addr in devices:
                        self.send_chunk(chunk, addr, devices)
                if progress is not None:
                    progress.update(chunk.size)
        finally:
            if progress is not None:
                progress.close()
        for device_data in self.transfer_data.values():
            device_data.success = all(
                chunk.acked for chunk in device_data.chunks
            )
        return self.transfer_data