marilib-pkg 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/frames.py +23 -0
- examples/mari_cloud.py +72 -0
- examples/mari_cloud_minimal.py +38 -0
- examples/mari_edge.py +73 -0
- examples/mari_edge_minimal.py +37 -0
- examples/mari_edge_stats.py +156 -0
- examples/uart.py +35 -0
- marilib/__init__.py +10 -0
- marilib/communication_adapter.py +212 -0
- marilib/latency.py +78 -0
- marilib/logger.py +211 -0
- marilib/mari_protocol.py +76 -0
- marilib/marilib.py +35 -0
- marilib/marilib_cloud.py +193 -0
- marilib/marilib_edge.py +248 -0
- marilib/model.py +393 -0
- marilib/protocol.py +109 -0
- marilib/serial_hdlc.py +228 -0
- marilib/serial_uart.py +84 -0
- marilib/tui.py +13 -0
- marilib/tui_cloud.py +158 -0
- marilib/tui_edge.py +185 -0
- marilib_pkg-0.6.0.dist-info/METADATA +57 -0
- marilib_pkg-0.6.0.dist-info/RECORD +30 -0
- marilib_pkg-0.6.0.dist-info/WHEEL +4 -0
- marilib_pkg-0.6.0.dist-info/licenses/AUTHORS +2 -0
- marilib_pkg-0.6.0.dist-info/licenses/LICENSE +28 -0
- tests/__init__.py +0 -0
- tests/test_hdlc.py +76 -0
- tests/test_protocol.py +35 -0
marilib/latency.py
ADDED
@@ -0,0 +1,78 @@
import struct
import threading
import time
from typing import TYPE_CHECKING
import math

from marilib.mari_protocol import Frame

if TYPE_CHECKING:
    from marilib.marilib import MariLib

LATENCY_PACKET_MAGIC = b"\x4c\x54"  # "LT" for Latency Test


class LatencyTester:
    """A thread-based class to periodically test latency to all nodes."""

    def __init__(self, marilib: "MariLib", interval: float = 10.0):
        self.marilib = marilib
        self.interval = interval
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def start(self):
        """Starts the latency testing thread."""
        print("[yellow]Latency tester started.[/]")
        self._thread.start()

    def stop(self):
        """Stops the latency testing thread."""
        self._stop_event.set()
        self._thread.join()
        print("[yellow]Latency tester stopped.[/]")

    def _run(self):
        """The main loop for the testing thread."""
        while not self._stop_event.is_set():
            if not self.marilib.gateway.nodes:
                time.sleep(self.interval)
                continue

            for node in list(self.marilib.gateway.nodes):
                if self._stop_event.is_set():
                    break
                self.send_latency_request(node.address)
                time.sleep(self.interval / len(self.marilib.gateway.nodes))

    def send_latency_request(self, address: int):
        """Sends a latency request packet to a specific address."""
        payload = LATENCY_PACKET_MAGIC + struct.pack("<d", time.time())
        self.marilib.send_frame(address, payload)

    def handle_response(self, frame: Frame):
        """
        Processes a latency response frame.
        This should be called when a LATENCY_DATA event is received.
        """
        if not frame.payload.startswith(LATENCY_PACKET_MAGIC):
            return
        try:
            # Unpack the original timestamp from the payload
            original_ts = struct.unpack("<d", frame.payload[2:10])[0]
            rtt = time.time() - original_ts
            if math.isnan(rtt) or math.isinf(rtt):
                return  # Ignore corrupted/invalid packets
            if rtt < 0 or rtt > 5.0:
                return  # Ignore this outlier

            node = self.marilib.gateway.get_node(frame.header.source)
            if node:
                # Update statistics for both the specific node and the whole gateway
                node.latency_stats.add_latency(rtt)
                self.marilib.gateway.latency_stats.add_latency(rtt)

        except (struct.error, IndexError):
            # Ignore packets that are too short or malformed
            pass
marilib/logger.py
ADDED
@@ -0,0 +1,211 @@
import csv
import os
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import IO, List, Dict

from marilib.model import MariGateway, MariNode


@dataclass
class MetricsLogger:
    """
    A metrics logger that saves statistics to CSV files with log rotation.
    """

    log_dir_base: str = "logs"
    rotation_interval_minutes: int = 1440  # 1 day
    already_logged_setup_parameters: bool = False
    log_interval_seconds: float = 1.0
    last_log_time: Dict[int, datetime] = field(default_factory=dict)

    def __post_init__(self):
        """
        Initializes the logger with rotation and setup logging capabilities.
        """
        try:
            self.rotation_interval = timedelta(minutes=self.rotation_interval_minutes)

            self.start_time = datetime.now()
            self.run_timestamp = self.start_time.strftime("%Y%m%d_%H%M%S")
            self.log_dir = os.path.join(self.log_dir_base, f"run_{self.run_timestamp}")
            os.makedirs(self.log_dir, exist_ok=True)

            self._gateway_file: IO[str] | None = None
            self._nodes_file: IO[str] | None = None
            self._events_file: IO[str] | None = None
            self._gateway_writer = None
            self._nodes_writer = None
            self._events_writer = None
            self.segment_start_time: datetime | None = None

            # Open events log file
            events_path = os.path.join(self.log_dir, "log_events.csv")
            self._events_file = open(events_path, "w", newline="", encoding="utf-8")
            self._events_writer = csv.writer(self._events_file)
            self._events_writer.writerow(
                ["timestamp", "gateway_address", "node_address", "event_name", "event_tag"]
            )

            self._open_new_segment()
            self.active = True

        except (IOError, OSError) as e:
            print(f"Error: Failed to initialize logger: {e}")
            self.active = False

    def log_setup_parameters(self, params: Dict[str, any] | None):
        """Creates and writes test setup parameters to metrics_setup.csv."""
        if not params or self.already_logged_setup_parameters:
            return
        # only log setup parameters once
        self.already_logged_setup_parameters = True

        setup_path = os.path.join(self.log_dir, "metrics_setup.csv")
        with open(setup_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["param", "value"])
            writer.writerow(["start_time", self.start_time.isoformat()])
            for key, value in params.items():
                writer.writerow([key, value])

    def _open_new_segment(self):
        self._close_segment_files()

        self.segment_start_time = datetime.now()
        segment_ts = self.segment_start_time.strftime("%H%M%S")

        gateway_path = os.path.join(self.log_dir, f"gateway_metrics_{segment_ts}.csv")
        nodes_path = os.path.join(self.log_dir, f"node_metrics_{segment_ts}.csv")

        self._gateway_file = open(gateway_path, "w", newline="", encoding="utf-8")
        self._gateway_writer = csv.writer(self._gateway_file)
        gateway_header = [
            "timestamp",
            "gateway_address",
            "schedule_id",
            "connected_nodes",
            "tx_total",
            "rx_total",
            "tx_rate_1s",
            "rx_rate_1s",
            "avg_latency_ms",
        ]
        self._gateway_writer.writerow(gateway_header)

        self._nodes_file = open(nodes_path, "w", newline="", encoding="utf-8")
        self._nodes_writer = csv.writer(self._nodes_file)
        nodes_header = [
            "timestamp",
            "gateway_address",
            "node_address",
            "is_alive",
            "tx_total",
            "rx_total",
            "tx_rate_1s",
            "rx_rate_1s",
            "success_rate_30s",
            "success_rate_total",
            "pdr_downlink",
            "pdr_uplink",
            "rssi_dbm_5s",
            "last_latency_ms",
            "avg_latency_ms",
        ]
        self._nodes_writer.writerow(nodes_header)

    def _check_for_rotation(self):
        if datetime.now() - self.segment_start_time >= self.rotation_interval:
            self._open_new_segment()

    def _log_common(self):
        if not self.active:
            return False
        self._check_for_rotation()
        return True

    def log_periodic_metrics(self, gateway: MariGateway, nodes: List[MariNode]):
        last_log_time = self.last_log_time.get(gateway.info.address, self.segment_start_time)
        if datetime.now() - last_log_time >= timedelta(seconds=self.log_interval_seconds):
            self.log_gateway_metrics(gateway)
            self.log_all_nodes_metrics(nodes)
            self.last_log_time[gateway.info.address] = datetime.now()

    def log_gateway_metrics(self, gateway: MariGateway):
        if not self._log_common() or self._gateway_writer is None:
            return

        timestamp = datetime.now().isoformat()
        row = [
            timestamp,
            f"0x{gateway.info.address:016X}",
            gateway.info.schedule_id,
            len(gateway.nodes),
            gateway.stats.sent_count(include_test_packets=False),
            gateway.stats.received_count(include_test_packets=False),
            gateway.stats.sent_count(1, include_test_packets=False),
            gateway.stats.received_count(1, include_test_packets=False),
            f"{gateway.latency_stats.avg_ms:.2f}",
        ]
        self._gateway_writer.writerow(row)

    def log_all_nodes_metrics(self, nodes: List[MariNode]):
        """Writes metrics for all nodes, handling rotation."""
        if not self._log_common() or self._nodes_writer is None:
            return

        timestamp = datetime.now().isoformat()
        for node in nodes:
            row = [
                timestamp,
                f"0x{node.gateway_address:016X}",
                f"0x{node.address:016X}",
                node.is_alive,
                node.stats.sent_count(include_test_packets=False),
                node.stats.received_count(include_test_packets=False),
                node.stats.sent_count(1, include_test_packets=False),
                node.stats.received_count(1, include_test_packets=False),
                f"{node.stats.success_rate(30):.2%}",
                f"{node.stats.success_rate():.2%}",
                f"{node.pdr_downlink:.2%}",
                f"{node.pdr_uplink:.2%}",
                node.stats.received_rssi_dbm(5),
                f"{node.latency_stats.last_ms:.2f}",
                f"{node.latency_stats.avg_ms:.2f}",
            ]
            self._nodes_writer.writerow(row)

    def log_event(
        self, gateway_address: int, node_address: int, event_name: str, event_tag: str = ""
    ):
        """Logs an event to the events log file."""
        if not self.active or self._events_writer is None:
            return

        timestamp = datetime.now().isoformat()
        row = [
            timestamp,
            f"0x{gateway_address:016X}",
            f"0x{node_address:016X}",
            event_name,
            event_tag,
        ]
        self._events_writer.writerow(row)
        if self._events_file:
            self._events_file.flush()

    def _close_segment_files(self):
        if self._gateway_file and not self._gateway_file.closed:
            self._gateway_file.close()
        if self._nodes_file and not self._nodes_file.closed:
            self._nodes_file.close()

    def close(self):
        if not self.active:
            return

        self._close_segment_files()
        if self._events_file and not self._events_file.closed:
            self._events_file.close()
        print(f"\nMetrics saved to: {self.log_dir}")
        self.active = False
marilib/mari_protocol.py
ADDED
@@ -0,0 +1,76 @@
import dataclasses
from dataclasses import dataclass

from marilib.protocol import Packet, PacketFieldMetadata, PacketType

MARI_PROTOCOL_VERSION = 2
MARI_BROADCAST_ADDRESS = 0xFFFFFFFFFFFFFFFF
MARI_NET_ID_DEFAULT = 0x0001


@dataclass
class HeaderStats(Packet):
    """Dataclass that holds MAC header stats."""

    metadata: list[PacketFieldMetadata] = dataclasses.field(
        default_factory=lambda: [
            PacketFieldMetadata(name="rssi", disp="rssi", length=1),
        ]
    )
    rssi: int = 0

    @property
    def rssi_dbm(self) -> int:
        if self.rssi > 127:
            return self.rssi - 255
        return self.rssi


@dataclass
class Header(Packet):
    """Dataclass that holds MAC header fields."""

    metadata: list[PacketFieldMetadata] = dataclasses.field(
        default_factory=lambda: [
            PacketFieldMetadata(name="version", disp="ver.", length=1),
            PacketFieldMetadata(name="type_", disp="type", length=1),
            PacketFieldMetadata(name="network_id", disp="net", length=2),
            PacketFieldMetadata(name="destination", disp="dst", length=8),
            PacketFieldMetadata(name="source", disp="src", length=8),
        ]
    )
    version: int = MARI_PROTOCOL_VERSION
    type_: int = PacketType.DATA
    network_id: int = MARI_NET_ID_DEFAULT
    destination: int = MARI_BROADCAST_ADDRESS
    source: int = 0x0000000000000000

    def __repr__(self):
        type_ = PacketType(self.type_).name
        return f"Header(version={self.version}, type_={type_}, network_id=0x{self.network_id:04x}, destination=0x{self.destination:016x}, source=0x{self.source:016x})"


@dataclass
class Frame:
    """Data class that holds a payload packet."""

    header: Header = None
    stats: HeaderStats = dataclasses.field(default_factory=HeaderStats)
    payload: bytes = b""

    def from_bytes(self, bytes_):
        self.header = Header().from_bytes(bytes_[0:20])
        if len(bytes_) > 20:
            self.stats = HeaderStats().from_bytes(bytes_[20:21])
        if len(bytes_) > 21:
            self.payload = bytes_[21:]
        return self

    def to_bytes(self, byteorder="little") -> bytes:
        header_bytes = self.header.to_bytes(byteorder)
        stats_bytes = self.stats.to_bytes(byteorder)
        return header_bytes + stats_bytes + self.payload

    def __repr__(self):
        header_no_metadata = dataclasses.replace(self.header, metadata=[])
        return f"Frame(header={header_no_metadata}, payload={self.payload})"
marilib/marilib.py
ADDED
@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod

from marilib.model import MariNode


class MarilibBase(ABC):
    """Base class for Marilib applications."""

    @abstractmethod
    def update(self):
        """Recurrent bookkeeping. Don't forget to call this periodically on your main loop."""

    @abstractmethod
    def nodes(self) -> list[MariNode]:
        """Returns all nodes in the network."""

    @abstractmethod
    def add_node(self, address: int, gateway_address: int = None) -> MariNode | None:
        """Adds a node to the network."""

    @abstractmethod
    def remove_node(self, address: int) -> MariNode | None:
        """Removes a node from the network."""

    @abstractmethod
    def send_frame(self, dst: int, payload: bytes):
        """Sends a frame to the network."""

    @abstractmethod
    def render_tui(self):
        """Renders the TUI."""

    @abstractmethod
    def close_tui(self):
        """Closes the TUI."""
marilib/marilib_cloud.py
ADDED
@@ -0,0 +1,193 @@
import threading
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable

from marilib.latency import LatencyTester
from marilib.mari_protocol import Frame, Header
from marilib.model import (
    EdgeEvent,
    GatewayInfo,
    MariGateway,
    MariNode,
    NodeInfoCloud,
)
from marilib.communication_adapter import MQTTAdapter
from marilib.marilib import MarilibBase
from marilib.tui_cloud import MarilibTUICloud

LOAD_PACKET_PAYLOAD = b"L"


@dataclass
class MarilibCloud(MarilibBase):
    """
    The MarilibCloud class runs in a computer.
    It is used to communicate with a Mari radio gateway (nRF5340) via MQTT.
    """

    cb_application: Callable[[EdgeEvent, MariNode | Frame | GatewayInfo], None]
    mqtt_interface: MQTTAdapter
    network_id: int
    tui: MarilibTUICloud | None = None

    logger: Any | None = None
    gateways: dict[int, MariGateway] = field(default_factory=dict)
    lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
    latency_tester: LatencyTester | None = None

    started_ts: datetime = field(default_factory=datetime.now)
    last_received_mqtt_data_ts: datetime = field(default_factory=datetime.now)
    main_file: str | None = None

    def __post_init__(self):
        self.setup_params = {
            "main_file": self.main_file or "unknown",
            "mqtt_host": self.mqtt_interface.host,
            "mqtt_port": self.mqtt_interface.port,
            "network_id": self.network_id_str,
        }
        self.mqtt_interface.set_network_id(self.network_id_str)
        self.mqtt_interface.set_on_data_received(self.on_mqtt_data_received)
        self.mqtt_interface.init()
        if self.logger:
            self.logger.log_setup_parameters(self.setup_params)

    # ============================ MarilibBase methods =========================

    def update(self):
        """Recurrent bookkeeping. Don't forget to call this periodically on your main loop."""
        with self.lock:
            # remove dead gateways
            self.gateways = {
                addr: gateway for addr, gateway in self.gateways.items() if gateway.is_alive
            }
            # update each gateway
            for gateway in self.gateways.values():
                gateway.update()
                if self.logger:
                    self.logger.log_periodic_metrics(gateway, gateway.nodes)

    @property
    def nodes(self) -> list[MariNode]:
        return [node for gateway in self.gateways.values() for node in gateway.nodes]

    def add_node(self, address: int, gateway_address: int = None) -> MariNode | None:
        with self.lock:
            gateway = self.gateways.get(gateway_address)
            if gateway:
                node = gateway.add_node(address)
                return node
            return None

    def remove_node(self, address: int, gateway_address: int = None) -> MariNode | None:
        with self.lock:
            gateway = self.gateways.get(gateway_address)
            if gateway:
                node = gateway.remove_node(address)
                return node
            return None

    def send_frame(self, dst: int, payload: bytes):
        """
        Sends a frame to a gateway via MQTT.
        Consists in publishing a message to the /mari/{network_id}/to_edge topic.
        """
        mari_frame = Frame(Header(destination=dst), payload=payload)

        self.mqtt_interface.send_data_to_edge(
            EdgeEvent.to_bytes(EdgeEvent.NODE_DATA) + mari_frame.to_bytes()
        )

    def render_tui(self):
        if self.tui:
            self.tui.render(self)

    def close_tui(self):
        if self.tui:
            self.tui.close()

    # ============================ MarilibCloud methods =========================

    @property
    def network_id_str(self) -> str:
        return f"{self.network_id:04X}"

    # ============================ Callbacks ===================================

    def handle_mqtt_data(self, data: bytes) -> tuple[bool, EdgeEvent, Any]:
        """
        Handles the MQTT data received from the MQTT broker:
        - parses the event
        - updates node or gateway information
        - returns the event type and data (if any)
        """
        if len(data) < 1:
            return False, EdgeEvent.UNKNOWN, None

        self.last_received_mqtt_data_ts = datetime.now()

        try:
            event_type = EdgeEvent(data[0])
        except ValueError:
            return False, EdgeEvent.UNKNOWN, None

        try:
            if event_type == EdgeEvent.NODE_JOINED:
                node_info = NodeInfoCloud().from_bytes(data[1:])
                if node := self.add_node(node_info.address, node_info.gateway_address):
                    return True, EdgeEvent.NODE_JOINED, node_info

            elif event_type == EdgeEvent.NODE_LEFT:
                node_info = NodeInfoCloud().from_bytes(data[1:])
                if node := self.remove_node(node_info.address, node_info.gateway_address):
                    return True, EdgeEvent.NODE_LEFT, node_info

            elif event_type == EdgeEvent.NODE_KEEP_ALIVE:
                node_info = NodeInfoCloud().from_bytes(data[1:])
                gateway = self.gateways.get(node_info.gateway_address)
                if gateway:
                    gateway.update_node_liveness(node_info.address)
                    return True, EdgeEvent.NODE_KEEP_ALIVE, node_info

            elif event_type == EdgeEvent.GATEWAY_INFO:
                gateway_info = GatewayInfo().from_bytes(data[1:])
                gateway = self.gateways.get(gateway_info.address)
                if not gateway:
                    # we are learning about a new gateway, so instantiate it and add it to the list
                    gateway = MariGateway(info=gateway_info)
                    self.gateways[gateway.info.address] = gateway
                else:
                    gateway.set_info(gateway_info)
                return True, EdgeEvent.GATEWAY_INFO, gateway_info

            elif event_type == EdgeEvent.NODE_DATA:
                frame = Frame().from_bytes(data[1:])

                gateway_address = frame.header.destination
                node_address = frame.header.source
                gateway = self.gateways.get(gateway_address)
                node = gateway.get_node(node_address)
                if not gateway or not node:
                    return False, EdgeEvent.UNKNOWN, None

                gateway.update_node_liveness(node_address)
                gateway.register_received_frame(frame, is_test_packet=False)
                return True, EdgeEvent.NODE_DATA, frame

        except Exception as e:
            print(f"Error handling MQTT data: {e}")

        # fallback result in case of error
        return False, EdgeEvent.UNKNOWN, None

    def on_mqtt_data_received(self, data: bytes):
        res, event_type, event_data = self.handle_mqtt_data(data)
        if res:
            if self.logger and event_type in [EdgeEvent.NODE_JOINED, EdgeEvent.NODE_LEFT]:
                # TODO: update the logging system to also support GATEWAY_INFO events from multiple gateways
                self.logger.log_event(
                    event_data.gateway_address, event_data.address, event_type.name
                )
            self.cb_application(event_type, event_data)