swarmit 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,620 @@
1
+ """Module containing the swarmit controller class."""
2
+
3
+ import dataclasses
4
+ import time
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ import serial
9
+ from cryptography.hazmat.primitives import hashes
10
+ from dotbot.logger import LOGGER
11
+ from dotbot.protocol import Frame, Header
12
+ from dotbot.serial_interface import SerialInterfaceException, get_default_port
13
+ from rich import print
14
+ from rich.console import Console
15
+ from rich.live import Live
16
+ from rich.table import Table
17
+ from tqdm import tqdm
18
+
19
+ from testbed.swarmit.adapter import (
20
+ GatewayAdapterBase,
21
+ MQTTAdapter,
22
+ SerialAdapter,
23
+ )
24
+ from testbed.swarmit.protocol import (
25
+ PayloadMessage,
26
+ PayloadOTAChunkRequest,
27
+ PayloadOTAStartRequest,
28
+ PayloadResetRequest,
29
+ PayloadStartRequest,
30
+ PayloadStatusRequest,
31
+ PayloadStopRequest,
32
+ StatusType,
33
+ SwarmitPayloadType,
34
+ register_parsers,
35
+ )
36
+
37
# Number of firmware bytes carried by each OTA chunk (see Controller.start_ota).
CHUNK_SIZE: int = 128
# Serial port used when none is configured; resolved once, at import time.
SERIAL_PORT_DEFAULT = get_default_port()
39
+
40
+
41
@dataclass()
class DataChunk:
    """One slice of a firmware image, as sent in an OTA chunk request."""

    index: int  # position of the chunk within the firmware image
    size: int  # number of bytes in ``data``
    data: bytes  # raw firmware bytes of this chunk
48
+
49
+
50
@dataclass()
class StartOtaData:
    """Summary of an OTA start: chunk count, firmware hash and acking devices."""

    chunks: int = 0  # total number of firmware chunks
    fw_hash: bytes = b""  # digest of the whole firmware image
    # Device IDs that acknowledged the OTA start request (fresh list each time).
    ids: list[str] = dataclasses.field(default_factory=list)
57
+
58
+
59
@dataclass()
class TransferDataStatus:
    """Per-device progress of a firmware transfer."""

    # Last attempt number used for each chunk index.
    retries: list[int] = dataclasses.field(default_factory=list)
    # Indexes of the chunks the device has acknowledged.
    chunks_acked: set[int] = dataclasses.field(default_factory=set)
    # Value of ``hashes_match`` reported in the device's chunk acks.
    hashes_match: bool = False
66
+
67
+
68
@dataclass()
class ResetLocation:
    """Position passed along with a reset request."""

    pos_x: int = 0  # x coordinate
    pos_y: int = 0  # y coordinate

    def __repr__(self):
        return "(x={}, y={})".format(self.pos_x, self.pos_y)
77
+
78
+
79
def print_status(status_data: dict[str, StatusType]) -> None:
    """Print a table with the current status of each device.

    Args:
        status_data: mapping of device ID to its reported status. Rows are
            sorted by device ID; ``Running`` devices are shown in bold cyan,
            every other status in bold green.
    """
    device_count = len(status_data)
    print()
    # Only exactly one device takes the singular ("0 devices found").
    print(f"{device_count} device{'' if device_count == 1 else 's'} found")
    print()
    status_table = Table()
    status_table.add_column("Device ID", style="magenta", no_wrap=True)
    status_table.add_column("Status", style="green", justify="center")
    with Live(status_table, refresh_per_second=4) as live:
        live.update(status_table)
        for device_id, status in sorted(status_data.items()):
            style = (
                "[bold cyan]" if status == StatusType.Running else "[bold green]"
            )
            status_table.add_row(f"{device_id}", f"{style}{status.name}")
96
+
97
+
98
def print_start_status(
    stopped_data: list[str], not_started: list[str]
) -> None:
    """Print a table showing, per device, whether the start succeeded.

    Args:
        stopped_data: device IDs shown with a green check mark, i.e. the
            devices that started (the parameter keeps its historical name
            for caller compatibility).
        not_started: device IDs shown with a red cross.
    """
    print("[bold]Start status:[/]")
    table = Table()
    table.add_column("Device ID", style="magenta", no_wrap=True)
    table.add_column("Status", style="green", justify="center")
    # Successful devices first, then failures; each group sorted by ID.
    outcome_groups = (
        (stopped_data, "[bold green]:heavy_check_mark:[/]"),
        (not_started, "[bold red]:x:[/]"),
    )
    with Live(table, refresh_per_second=4) as live:
        live.update(table)
        for device_ids, marker in outcome_groups:
            for device_id in sorted(device_ids):
                table.add_row(f"{device_id}", marker)
114
+
115
+
116
def print_stop_status(stopped_data: list[str], not_stopped: list[str]) -> None:
    """Print a table showing, per device, whether the stop succeeded.

    Args:
        stopped_data: device IDs shown with a green check mark.
        not_stopped: device IDs shown with a red cross.
    """
    print("[bold]Stop status:[/]")
    table = Table()
    table.add_column("Device ID", style="magenta", no_wrap=True)
    table.add_column("Status", style="green", justify="center")
    # Successful devices first, then failures; each group sorted by ID.
    outcome_groups = (
        (stopped_data, "[bold green]:heavy_check_mark:[/]"),
        (not_stopped, "[bold red]:x:[/]"),
    )
    with Live(table, refresh_per_second=4) as live:
        live.update(table)
        for device_ids, marker in outcome_groups:
            for device_id in sorted(device_ids):
                table.add_row(f"{device_id}", marker)
130
+
131
+
132
def print_transfer_status(
    status: dict[str, TransferDataStatus], start_data: "StartOtaData"
) -> None:
    """Print a table summarizing the firmware transfer for each device.

    Args:
        status: mapping of device ID to its transfer status.
        start_data: OTA start data; its ``chunks`` field is the total chunk
            count shown next to each device's acknowledged-chunk count.
    """
    print()
    print("[bold]Transfer status:[/]")
    transfer_status_table = Table()
    transfer_status_table.add_column(
        "Device ID", style="magenta", no_wrap=True
    )
    transfer_status_table.add_column(
        "Chunks acked", style="green", justify="center"
    )
    transfer_status_table.add_column(
        "Hashes match", style="green", justify="center"
    )
    with Live(transfer_status_table, refresh_per_second=4) as live:
        live.update(transfer_status_table)
        # Sort on the device ID only; status objects are not orderable.
        for device_id, device_status in sorted(
            status.items(), key=lambda item: item[0]
        ):
            hashes_match = bool(device_status.hashes_match)
            color = "bold green" if hashes_match else "bold red"
            transfer_status_table.add_row(
                f"{device_id}",
                f"{len(device_status.chunks_acked)}/{start_data.chunks}",
                f"[{color}]{hashes_match}[/]",
            )
161
+
162
+
163
def wait_for_done(timeout, condition_func):
    """Poll ``condition_func`` until it returns true or ``timeout`` elapses.

    The deadline is measured on a monotonic clock, so time spent inside
    ``condition_func`` counts toward the timeout. A slow condition cannot
    stretch the wait far beyond ``timeout``.

    Args:
        timeout: maximum time to wait, in seconds. A value <= 0 returns
            ``False`` without calling ``condition_func``.
        condition_func: zero-argument callable, polled every 10 ms.

    Returns:
        ``True`` if the condition was met before the deadline, else ``False``.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if condition_func():
            return True
        time.sleep(0.01)
    return False
171
+
172
+
173
@dataclass()
class ControllerSettings:
    """Configuration consumed by the testbed controller.

    ``edge`` selects the transport: the MQTT broker (``mqtt_host`` and
    ``mqtt_port``) when true, otherwise the serial gateway (``serial_port``
    and ``serial_baudrate``). ``devices`` restricts commands to the listed
    device IDs; an empty list targets every device.
    """

    # Serial gateway connection.
    serial_port: str = SERIAL_PORT_DEFAULT
    serial_baudrate: int = 1_000_000
    # MQTT broker connection, used when ``edge`` is true.
    mqtt_host: str = "argus.paris.inria.fr"
    mqtt_port: int = 8883
    edge: bool = False
    devices: list[str] = dataclasses.field(default_factory=list)
183
+
184
+
185
class Controller:
    """Class used to control a swarm testbed.

    Requests are either broadcast (device ID ``"0"``) or addressed to the
    device IDs listed in ``settings.devices``. Replies arrive through
    :meth:`on_data_received`, which is registered as the adapter callback.
    ``expected_reply`` tells that callback which notification type the
    request currently in progress is waiting for.
    """

    def __init__(self, settings: ControllerSettings):
        """Create the gateway adapter selected by ``settings`` and listen on it.

        Raises:
            SerialInterfaceException, serial.serialutil.SerialException:
                if the serial adapter cannot be opened (the error is printed
                to the console first).
        """
        self.logger = LOGGER.bind(context=__name__)
        self.settings = settings
        self._interface: Optional[GatewayAdapterBase] = None
        self.status_data: dict[str, StatusType] = {}
        self.started_data: list[str] = []
        self.stopped_data: list[str] = []
        self.chunks: list[DataChunk] = []
        # SHA-256 digest of the firmware being flashed; set by start_ota().
        self.fw_hash: bytes = b""
        self.start_ota_data: StartOtaData = StartOtaData()
        self.transfer_data: dict[str, TransferDataStatus] = {}
        self._known_devices: dict[str, StatusType] = {}
        self.expected_reply: Optional[SwarmitPayloadType] = None
        register_parsers()
        if self.settings.edge is True:
            self._interface = MQTTAdapter(
                self.settings.mqtt_host, self.settings.mqtt_port
            )
        else:
            try:
                self._interface = SerialAdapter(
                    self.settings.serial_port, self.settings.serial_baudrate
                )
            except (
                SerialInterfaceException,
                serial.serialutil.SerialException,
            ) as exc:
                console = Console()
                console.print(f"[bold red]Error:[/] {exc}")
                # Without an interface the controller is unusable: propagate
                # the real error instead of failing below on ``None.init``.
                raise
        self._interface.init(self.on_data_received)

    @property
    def known_devices(self) -> dict[str, StatusType]:
        """Return the known devices, querying the testbed on first use.

        The result of :meth:`status` is cached. An empty result is not kept,
        so the testbed is queried again (one-second wait) on the next access.
        """
        if not self._known_devices:
            self._known_devices = self.status()
        return self._known_devices

    def _selected_devices_with_status(self, wanted: StatusType) -> list[str]:
        """Return known device IDs in state ``wanted``, limited to ``settings.devices`` if set."""
        selected = self.settings.devices
        return [
            device_id
            for device_id, status in self.known_devices.items()
            if status == wanted and (not selected or device_id in selected)
        ]

    @property
    def running_devices(self) -> list[str]:
        """Return the selected devices whose status is ``Running``."""
        return self._selected_devices_with_status(StatusType.Running)

    @property
    def resetting_devices(self) -> list[str]:
        """Return the selected devices whose status is ``Resetting``."""
        return self._selected_devices_with_status(StatusType.Resetting)

    @property
    def ready_devices(self) -> list[str]:
        """Return the selected ready devices (status ``Bootloader``)."""
        return self._selected_devices_with_status(StatusType.Bootloader)

    @property
    def interface(self) -> GatewayAdapterBase:
        """Return the gateway adapter."""
        return self._interface

    def terminate(self):
        """Terminate the controller by closing the gateway adapter."""
        self.interface.close()

    def send_frame(self, frame: Frame):
        """Serialize ``frame`` and send it through the gateway adapter."""
        self.interface.send_data(frame.to_bytes())

    def _log_event(self, device_id: str, frame: Frame) -> None:
        """Log a GPIO or LOG event notification, honouring ``settings.devices``."""
        if self.settings.devices and device_id not in self.settings.devices:
            return
        logger = self.logger.bind(
            deviceid=device_id,
            notification=frame.payload_type.name,
            timestamp=frame.payload.timestamp,
            data_size=frame.payload.count,
            data=frame.payload.data,
        )
        if (
            frame.payload_type
            == SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_GPIO
        ):
            logger.info("GPIO event")
        else:
            logger.info("LOG event")

    def on_data_received(self, data):
        """Handle a raw frame delivered by the gateway adapter.

        Status, started, stopped and OTA-start-ack notifications are recorded
        only while they match ``expected_reply``. OTA chunk acks are recorded
        for any device taking part in the current transfer. GPIO/LOG events
        are logged. Anything else is logged as unexpected.
        """
        frame = Frame().from_bytes(data)
        payload_type = frame.payload_type
        # Payload types below SWARMIT_REQUEST_STATUS are ignored here.
        if payload_type < SwarmitPayloadType.SWARMIT_REQUEST_STATUS:
            return
        device_id = f"{frame.payload.device_id:08X}"
        # Read once so the whole frame is handled against a single value
        # (NOTE(review): assumes the callback may run concurrently with the
        # request methods that update ``expected_reply``).
        expected = self.expected_reply
        if (
            payload_type == SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS
            and expected == SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS
        ):
            self.status_data[device_id] = StatusType(frame.payload.status)
        elif (
            payload_type == SwarmitPayloadType.SWARMIT_NOTIFICATION_STARTED
            and expected == SwarmitPayloadType.SWARMIT_NOTIFICATION_STARTED
        ):
            if device_id not in self.started_data:
                self.started_data.append(device_id)
        elif (
            payload_type == SwarmitPayloadType.SWARMIT_NOTIFICATION_STOPPED
            and expected == SwarmitPayloadType.SWARMIT_NOTIFICATION_STOPPED
        ):
            if device_id not in self.stopped_data:
                self.stopped_data.append(device_id)
        elif (
            payload_type
            == SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK
            and expected
            == SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK
        ):
            if device_id not in self.start_ota_data.ids:
                self.start_ota_data.ids.append(device_id)
        elif (
            payload_type
            == SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_CHUNK_ACK
        ):
            transfer_status = self.transfer_data.get(device_id)
            if transfer_status is None:
                # Ack from a device outside the current transfer (or a late
                # ack after it ended): nothing to update.
                self.logger.warning(
                    "Unexpected OTA chunk ack", deviceid=device_id
                )
                return
            if frame.payload.index not in transfer_status.chunks_acked:
                transfer_status.chunks_acked.add(frame.payload.index)
                transfer_status.hashes_match = frame.payload.hashes_match
        elif payload_type in (
            SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_GPIO,
            SwarmitPayloadType.SWARMIT_NOTIFICATION_EVENT_LOG,
        ):
            self._log_event(device_id, frame)
        elif payload_type != expected:
            self.logger.warning(
                "Unexpected payload",
                payload_type=hex(payload_type),
                # ``expected`` is None between requests; hex(None) would raise.
                expected=hex(expected) if expected is not None else None,
            )
        else:
            self.logger.error("Unknown payload type", payload_type=payload_type)

    def status(self):
        """Broadcast a status request and collect replies for one second.

        Returns:
            Mapping of device ID (8-digit uppercase hex) to reported status.
        """
        previous_reply = self.expected_reply
        self.status_data = {}
        payload = PayloadStatusRequest(device_id=0)
        frame = Frame(header=Header(), payload=payload)
        self.expected_reply = SwarmitPayloadType.SWARMIT_NOTIFICATION_STATUS
        try:
            self.send_frame(frame)
            # The number of devices is unknown, so listen for the full second.
            time.sleep(1)
        finally:
            # status() is also reached through ``known_devices`` while other
            # requests wait for their replies: restore their expectation.
            self.expected_reply = previous_reply
        return self.status_data

    def _send_start(self, device_id: str):
        """Send a start request (``"0"`` broadcasts); wait up to 3 s for STARTED replies."""
        broadcast = device_id == "0"
        # Snapshot once: ``ready_devices`` may trigger a blocking status()
        # query, which must not run on every poll.
        ready_devices = sorted(self.ready_devices) if broadcast else []

        def is_started():
            if broadcast:
                return sorted(self.started_data) == ready_devices
            return device_id in self.started_data

        self.expected_reply = SwarmitPayloadType.SWARMIT_NOTIFICATION_STARTED
        try:
            payload = PayloadStartRequest(device_id=int(device_id, base=16))
            self.send_frame(Frame(header=Header(), payload=payload))
            wait_for_done(3, is_started)
        finally:
            self.expected_reply = None

    def start(self):
        """Start the application on the selected ready devices.

        Returns:
            IDs of the devices that reported STARTED.
        """
        self.started_data = []
        ready_devices = self.ready_devices
        if not self.settings.devices:
            self._send_start("0")
        else:
            for device_id in self.settings.devices:
                if device_id in ready_devices:
                    self._send_start(device_id)
        return self.started_data

    def _send_stop(self, device_id: str):
        """Send a stop request (``"0"`` broadcasts); wait up to 3 s for STOPPED replies."""
        stoppable_devices = sorted(
            self.running_devices + self.resetting_devices
        )

        def is_stopped():
            if device_id == "0":
                return sorted(self.stopped_data) == stoppable_devices
            return device_id in self.stopped_data

        self.expected_reply = SwarmitPayloadType.SWARMIT_NOTIFICATION_STOPPED
        try:
            payload = PayloadStopRequest(device_id=int(device_id, base=16))
            self.send_frame(Frame(header=Header(), payload=payload))
            wait_for_done(3, is_stopped)
        finally:
            self.expected_reply = None

    def stop(self):
        """Stop the application on the selected running or resetting devices.

        Returns:
            IDs of the devices that reported STOPPED.
        """
        self.stopped_data = []
        stoppable_devices = self.running_devices + self.resetting_devices
        if not self.settings.devices:
            self._send_stop("0")
        else:
            for device_id in self.settings.devices:
                if device_id in stoppable_devices:
                    self._send_stop(device_id)
        return self.stopped_data

    def _send_reset(self, device_id: str, location: ResetLocation):
        """Send a reset request carrying ``location`` to ``device_id``."""
        payload = PayloadResetRequest(
            device_id=int(device_id, base=16),
            pos_x=location.pos_x,
            pos_y=location.pos_y,
        )
        self.send_frame(Frame(header=Header(), payload=payload))

    def reset(self, locations: dict[str, ResetLocation]):
        """Send reset requests to the selected ready devices.

        Only IDs listed in ``settings.devices`` that are ready are reset, so
        nothing is sent when ``settings.devices`` is empty.

        Raises:
            KeyError: if a device to reset has no entry in ``locations``.
        """
        ready_devices = self.ready_devices
        for device_id in self.settings.devices:
            if device_id not in ready_devices:
                continue
            print(
                f"Resetting device {device_id} with location {locations[device_id]}"
            )
            self._send_reset(device_id, locations[device_id])

    def monitor(self):
        """Block forever while adapter callbacks log incoming events."""
        self.logger.info("Monitoring testbed")
        while True:
            time.sleep(0.01)

    def _send_message(self, device_id, message):
        """Send the text ``message`` to ``device_id`` (``"0"`` broadcasts)."""
        encoded = message.encode()
        payload = PayloadMessage(
            device_id=int(device_id, base=16),
            # Byte length, not character count: they differ for non-ASCII.
            count=len(encoded),
            message=encoded,
        )
        self.send_frame(Frame(header=Header(), payload=payload))

    def send_message(self, message):
        """Send a message to all devices, or to the selected running ones."""
        running_devices = self.running_devices
        if not self.settings.devices:
            self._send_message("0", message)
        else:
            for device_id in self.settings.devices:
                if device_id in running_devices:
                    self._send_message(device_id, message)

    def _send_start_ota(self, device_id: str, firmware: bytes):
        """Send an OTA start request and wait up to 3 s for the acks.

        Uses ``self.chunks`` and ``self.fw_hash`` prepared by start_ota().
        """
        broadcast = device_id == "0"
        # Snapshot once (see _send_start).
        ready_devices = sorted(self.ready_devices) if broadcast else []

        def is_start_ota_acknowledged():
            if broadcast:
                return sorted(self.start_ota_data.ids) == ready_devices
            return device_id in self.start_ota_data.ids

        payload = PayloadOTAStartRequest(
            device_id=int(device_id, base=16),
            fw_length=len(firmware),
            fw_chunk_count=len(self.chunks),
            fw_hash=self.fw_hash,
        )
        self.send_frame(Frame(header=Header(), payload=payload))
        wait_for_done(3, is_start_ota_acknowledged)

    def start_ota(self, firmware) -> StartOtaData:
        """Split ``firmware`` into chunks, hash it and announce the OTA.

        Returns:
            The chunk count, firmware SHA-256 digest and the IDs of the
            devices that acknowledged the start request.
        """
        self.start_ota_data = StartOtaData()
        self.chunks = []
        digest = hashes.Hash(hashes.SHA256())
        # Slice by offset: the last chunk holds whatever remains (1 to
        # CHUNK_SIZE bytes), including when the length is an exact multiple.
        for chunk_idx, offset in enumerate(
            range(0, len(firmware), CHUNK_SIZE)
        ):
            data = firmware[offset : offset + CHUNK_SIZE]
            digest.update(data)
            self.chunks.append(
                DataChunk(index=chunk_idx, size=len(data), data=data)
            )
        self.fw_hash = digest.finalize()
        self.start_ota_data.fw_hash = self.fw_hash
        self.start_ota_data.chunks = len(self.chunks)
        self.expected_reply = (
            SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_START_ACK
        )
        try:
            if not self.settings.devices:
                print("Broadcast start ota notification...")
                self._send_start_ota("0", firmware)
            else:
                for device_id in self.settings.devices:
                    print(f"Sending start ota notification to {device_id}...")
                    self._send_start_ota(device_id, firmware)
        finally:
            self.expected_reply = None
        return self.start_ota_data

    def send_chunk(self, chunk, device_id: str):
        """Send one firmware chunk, retrying until it is acknowledged.

        The chunk is sent at most 3 times. After each send, the method waits
        up to 1 second for the acknowledgement(s) before the next attempt.
        ``device_id == "0"`` broadcasts the chunk and waits for every ready
        device. The attempt number is recorded in each target's ``retries``.
        """
        broadcast = device_id == "0"
        ready_devices = sorted(self.ready_devices) if broadcast else []

        def is_chunk_acknowledged():
            if broadcast:
                return sorted(self.transfer_data.keys()) == ready_devices and all(
                    chunk.index in status.chunks_acked
                    for status in self.transfer_data.values()
                )
            status = self.transfer_data.get(device_id)
            return status is not None and chunk.index in status.chunks_acked

        # Distinct name: ``device_id`` must keep addressing the request.
        targets = ready_devices if broadcast else [device_id]
        for attempt in range(3):
            if is_chunk_acknowledged():
                return
            payload = PayloadOTAChunkRequest(
                device_id=int(device_id, base=16),
                index=chunk.index,
                count=chunk.size,
                chunk=chunk.data,
            )
            self.send_frame(Frame(header=Header(), payload=payload))
            for target in targets:
                target_status = self.transfer_data.get(target)
                if target_status is not None:
                    target_status.retries[chunk.index] = attempt
            time.sleep(0.01)
            deadline = time.monotonic() + 1
            while time.monotonic() < deadline:
                if is_chunk_acknowledged():
                    return
                time.sleep(0.001)

    def transfer(self, firmware):
        """Transfer the chunks prepared by start_ota() to the devices.

        Returns:
            Per-device transfer status (acked chunks, retries, hash match).
        """
        data_size = len(firmware)
        progress = tqdm(
            range(0, data_size),
            unit="B",
            unit_scale=False,
            colour="green",
            ncols=100,
        )
        progress.set_description(
            f"Loading firmware ({int(data_size / 1024)}kB)"
        )
        try:
            targets = self.settings.devices or self.ready_devices
            self.transfer_data = {
                device_id: TransferDataStatus(retries=[0] * len(self.chunks))
                for device_id in targets
            }
            self.expected_reply = (
                SwarmitPayloadType.SWARMIT_NOTIFICATION_OTA_CHUNK_ACK
            )
            for chunk in self.chunks:
                if not self.settings.devices:
                    self.send_chunk(chunk, "0")
                else:
                    for device_id in self.settings.devices:
                        self.send_chunk(chunk, device_id)
                progress.update(chunk.size)
        finally:
            progress.close()
            self.expected_reply = None
        return self.transfer_data