marilib-pkg 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
marilib/latency.py ADDED
@@ -0,0 +1,78 @@
1
+ import struct
2
+ import threading
3
+ import time
4
+ from typing import TYPE_CHECKING
5
+ import math
6
+
7
+ from marilib.mari_protocol import Frame
8
+
9
+ if TYPE_CHECKING:
10
+ from marilib.marilib import MariLib
11
+
12
LATENCY_PACKET_MAGIC = b"\x4c\x54"  # "LT" for Latency Test


class LatencyTester:
    """A thread-based class to periodically test latency to all nodes.

    Every ``interval`` seconds the background thread sends one latency request
    to each node in ``marilib.gateway.nodes``, spreading the requests evenly
    over the interval. A request payload is ``LATENCY_PACKET_MAGIC`` followed
    by the send time packed as a little-endian double; :meth:`handle_response`
    computes the round-trip time from the echoed payload.

    NOTE(review): this expects ``marilib`` to expose a single ``gateway``
    attribute; MarilibCloud exposes a ``gateways`` dict instead — confirm
    which MariLib implementation this tester is meant for.
    """

    def __init__(self, marilib: "MariLib", interval: float = 10.0):
        self.marilib = marilib
        self.interval = interval
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def start(self):
        """Starts the latency testing thread."""
        # NOTE(review): rich-style markup is printed literally by the builtin
        # print() unless rich's print replaces it.
        print("[yellow]Latency tester started.[/]")
        self._thread.start()

    def stop(self):
        """Stops the latency testing thread and waits for it to exit.

        Safe to call even if :meth:`start` was never called.
        """
        self._stop_event.set()
        # Thread.join() raises RuntimeError on a thread that was never started.
        if self._thread.is_alive():
            self._thread.join()
        print("[yellow]Latency tester stopped.[/]")

    def _run(self):
        """The main loop for the testing thread."""
        while not self._stop_event.is_set():
            # Work on a snapshot: the node list may change concurrently, and
            # its live length could drop to zero before the division below.
            nodes = list(self.marilib.gateway.nodes)
            if not nodes:
                # Event.wait (not time.sleep) so stop() is not delayed.
                self._stop_event.wait(self.interval)
                continue

            pause = self.interval / len(nodes)
            for node in nodes:
                if self._stop_event.is_set():
                    break
                self.send_latency_request(node.address)
                self._stop_event.wait(pause)

    def send_latency_request(self, address: int):
        """Sends a latency request packet (magic + send timestamp) to ``address``."""
        payload = LATENCY_PACKET_MAGIC + struct.pack("<d", time.time())
        self.marilib.send_frame(address, payload)

    def handle_response(self, frame: "Frame"):
        """
        Processes a latency response frame.
        This should be called when a LATENCY_DATA event is received.

        Frames whose payload does not start with ``LATENCY_PACKET_MAGIC``,
        that are too short, or whose RTT is non-finite, negative or above
        5 seconds are ignored. Otherwise the RTT (seconds) is added to the
        source node's and the gateway's latency statistics.
        """
        if not frame.payload.startswith(LATENCY_PACKET_MAGIC):
            return
        try:
            # Unpack the original timestamp from the payload
            original_ts = struct.unpack("<d", frame.payload[2:10])[0]
            rtt = time.time() - original_ts
            if math.isnan(rtt) or math.isinf(rtt):
                return  # Ignore corrupted/invalid packets
            if rtt < 0 or rtt > 5.0:
                return  # Ignore this outlier

            node = self.marilib.gateway.get_node(frame.header.source)
            if node:
                # Update statistics for both the specific node and the whole gateway
                node.latency_stats.add_latency(rtt)
                self.marilib.gateway.latency_stats.add_latency(rtt)

        except (struct.error, IndexError):
            # Ignore packets that are too short or malformed
            pass
marilib/logger.py ADDED
@@ -0,0 +1,211 @@
1
+ import csv
2
+ import os
3
+ from dataclasses import dataclass, field
4
+ from datetime import datetime, timedelta
5
+ from typing import IO, List, Dict
6
+
7
+ from marilib.model import MariGateway, MariNode
8
+
9
+
10
+ @dataclass
11
+ class MetricsLogger:
12
+ """
13
+ A metrics logger that saves statistics to CSV files with log rotation.
14
+ """
15
+
16
+ log_dir_base: str = "logs"
17
+ rotation_interval_minutes: int = 1440 # 1 day
18
+ already_logged_setup_parameters: bool = False
19
+ log_interval_seconds: float = 1.0
20
+ last_log_time: Dict[int, datetime] = field(default_factory=dict)
21
+
22
+ def __post_init__(self):
23
+ """
24
+ Initializes the logger with rotation and setup logging capabilities.
25
+ """
26
+ try:
27
+ self.rotation_interval = timedelta(minutes=self.rotation_interval_minutes)
28
+
29
+ self.start_time = datetime.now()
30
+ self.run_timestamp = self.start_time.strftime("%Y%m%d_%H%M%S")
31
+ self.log_dir = os.path.join(self.log_dir_base, f"run_{self.run_timestamp}")
32
+ os.makedirs(self.log_dir, exist_ok=True)
33
+
34
+ self._gateway_file: IO[str] | None = None
35
+ self._nodes_file: IO[str] | None = None
36
+ self._events_file: IO[str] | None = None
37
+ self._gateway_writer = None
38
+ self._nodes_writer = None
39
+ self._events_writer = None
40
+ self.segment_start_time: datetime | None = None
41
+
42
+ # Open events log file
43
+ events_path = os.path.join(self.log_dir, "log_events.csv")
44
+ self._events_file = open(events_path, "w", newline="", encoding="utf-8")
45
+ self._events_writer = csv.writer(self._events_file)
46
+ self._events_writer.writerow(
47
+ ["timestamp", "gateway_address", "node_address", "event_name", "event_tag"]
48
+ )
49
+
50
+ self._open_new_segment()
51
+ self.active = True
52
+
53
+ except (IOError, OSError) as e:
54
+ print(f"Error: Failed to initialize logger: {e}")
55
+ self.active = False
56
+
57
+ def log_setup_parameters(self, params: Dict[str, any] | None):
58
+ """Creates and writes test setup parameters to metrics_setup.csv."""
59
+ if not params or self.already_logged_setup_parameters:
60
+ return
61
+ # only log setup parameters once
62
+ self.already_logged_setup_parameters = True
63
+
64
+ setup_path = os.path.join(self.log_dir, "metrics_setup.csv")
65
+ with open(setup_path, "w", newline="", encoding="utf-8") as f:
66
+ writer = csv.writer(f)
67
+ writer.writerow(["param", "value"])
68
+ writer.writerow(["start_time", self.start_time.isoformat()])
69
+ for key, value in params.items():
70
+ writer.writerow([key, value])
71
+
72
+ def _open_new_segment(self):
73
+ self._close_segment_files()
74
+
75
+ self.segment_start_time = datetime.now()
76
+ segment_ts = self.segment_start_time.strftime("%H%M%S")
77
+
78
+ gateway_path = os.path.join(self.log_dir, f"gateway_metrics_{segment_ts}.csv")
79
+ nodes_path = os.path.join(self.log_dir, f"node_metrics_{segment_ts}.csv")
80
+
81
+ self._gateway_file = open(gateway_path, "w", newline="", encoding="utf-8")
82
+ self._gateway_writer = csv.writer(self._gateway_file)
83
+ gateway_header = [
84
+ "timestamp",
85
+ "gateway_address",
86
+ "schedule_id",
87
+ "connected_nodes",
88
+ "tx_total",
89
+ "rx_total",
90
+ "tx_rate_1s",
91
+ "rx_rate_1s",
92
+ "avg_latency_ms",
93
+ ]
94
+ self._gateway_writer.writerow(gateway_header)
95
+
96
+ self._nodes_file = open(nodes_path, "w", newline="", encoding="utf-8")
97
+ self._nodes_writer = csv.writer(self._nodes_file)
98
+ nodes_header = [
99
+ "timestamp",
100
+ "gateway_address",
101
+ "node_address",
102
+ "is_alive",
103
+ "tx_total",
104
+ "rx_total",
105
+ "tx_rate_1s",
106
+ "rx_rate_1s",
107
+ "success_rate_30s",
108
+ "success_rate_total",
109
+ "pdr_downlink",
110
+ "pdr_uplink",
111
+ "rssi_dbm_5s",
112
+ "last_latency_ms",
113
+ "avg_latency_ms",
114
+ ]
115
+ self._nodes_writer.writerow(nodes_header)
116
+
117
+ def _check_for_rotation(self):
118
+ if datetime.now() - self.segment_start_time >= self.rotation_interval:
119
+ self._open_new_segment()
120
+
121
+ def _log_common(self):
122
+ if not self.active:
123
+ return False
124
+ self._check_for_rotation()
125
+ return True
126
+
127
+ def log_periodic_metrics(self, gateway: MariGateway, nodes: List[MariNode]):
128
+ last_log_time = self.last_log_time.get(gateway.info.address, self.segment_start_time)
129
+ if datetime.now() - last_log_time >= timedelta(seconds=self.log_interval_seconds):
130
+ self.log_gateway_metrics(gateway)
131
+ self.log_all_nodes_metrics(nodes)
132
+ self.last_log_time[gateway.info.address] = datetime.now()
133
+
134
+ def log_gateway_metrics(self, gateway: MariGateway):
135
+ if not self._log_common() or self._gateway_writer is None:
136
+ return
137
+
138
+ timestamp = datetime.now().isoformat()
139
+ row = [
140
+ timestamp,
141
+ f"0x{gateway.info.address:016X}",
142
+ gateway.info.schedule_id,
143
+ len(gateway.nodes),
144
+ gateway.stats.sent_count(include_test_packets=False),
145
+ gateway.stats.received_count(include_test_packets=False),
146
+ gateway.stats.sent_count(1, include_test_packets=False),
147
+ gateway.stats.received_count(1, include_test_packets=False),
148
+ f"{gateway.latency_stats.avg_ms:.2f}",
149
+ ]
150
+ self._gateway_writer.writerow(row)
151
+
152
+ def log_all_nodes_metrics(self, nodes: List[MariNode]):
153
+ """Writes metrics for all nodes, handling rotation."""
154
+ if not self._log_common() or self._nodes_writer is None:
155
+ return
156
+
157
+ timestamp = datetime.now().isoformat()
158
+ for node in nodes:
159
+ row = [
160
+ timestamp,
161
+ f"0x{node.gateway_address:016X}",
162
+ f"0x{node.address:016X}",
163
+ node.is_alive,
164
+ node.stats.sent_count(include_test_packets=False),
165
+ node.stats.received_count(include_test_packets=False),
166
+ node.stats.sent_count(1, include_test_packets=False),
167
+ node.stats.received_count(1, include_test_packets=False),
168
+ f"{node.stats.success_rate(30):.2%}",
169
+ f"{node.stats.success_rate():.2%}",
170
+ f"{node.pdr_downlink:.2%}",
171
+ f"{node.pdr_uplink:.2%}",
172
+ node.stats.received_rssi_dbm(5),
173
+ f"{node.latency_stats.last_ms:.2f}",
174
+ f"{node.latency_stats.avg_ms:.2f}",
175
+ ]
176
+ self._nodes_writer.writerow(row)
177
+
178
+ def log_event(
179
+ self, gateway_address: int, node_address: int, event_name: str, event_tag: str = ""
180
+ ):
181
+ """Logs an event to the events log file."""
182
+ if not self.active or self._events_writer is None:
183
+ return
184
+
185
+ timestamp = datetime.now().isoformat()
186
+ row = [
187
+ timestamp,
188
+ f"0x{gateway_address:016X}",
189
+ f"0x{node_address:016X}",
190
+ event_name,
191
+ event_tag,
192
+ ]
193
+ self._events_writer.writerow(row)
194
+ if self._events_file:
195
+ self._events_file.flush()
196
+
197
+ def _close_segment_files(self):
198
+ if self._gateway_file and not self._gateway_file.closed:
199
+ self._gateway_file.close()
200
+ if self._nodes_file and not self._nodes_file.closed:
201
+ self._nodes_file.close()
202
+
203
+ def close(self):
204
+ if not self.active:
205
+ return
206
+
207
+ self._close_segment_files()
208
+ if self._events_file and not self._events_file.closed:
209
+ self._events_file.close()
210
+ print(f"\nMetrics saved to: {self.log_dir}")
211
+ self.active = False
@@ -0,0 +1,76 @@
1
+ import dataclasses
2
+ from dataclasses import dataclass
3
+
4
+ from marilib.protocol import Packet, PacketFieldMetadata, PacketType
5
+
6
# Default value of Header.version.
MARI_PROTOCOL_VERSION = 2
# All-ones 64-bit address; Header's default destination (broadcast).
MARI_BROADCAST_ADDRESS = 0xFFFF_FFFF_FFFF_FFFF
# Default value of Header.network_id.
MARI_NET_ID_DEFAULT = 0x0001
9
+
10
+
11
@dataclass
class HeaderStats(Packet):
    """Dataclass that holds MAC header stats (a single 1-byte RSSI field)."""

    metadata: list[PacketFieldMetadata] = dataclasses.field(
        default_factory=lambda: [
            PacketFieldMetadata(name="rssi", disp="rssi", length=1),
        ]
    )
    # Raw RSSI byte; presumably 0..255 as parsed from the 1-byte field.
    rssi: int = 0

    @property
    def rssi_dbm(self) -> int:
        """RSSI in dBm, decoding the raw byte as an 8-bit two's-complement value."""
        # Raw values 128..255 encode -128..-1, i.e. value - 256 (not - 255,
        # which mapped 0xFF to 0 instead of -1).
        if self.rssi > 127:
            return self.rssi - 256
        return self.rssi
27
+
28
+
29
@dataclass
class Header(Packet):
    """Dataclass that holds MAC header fields.

    Field layout per ``metadata`` (20 bytes total): version (1), type (1),
    network id (2), destination address (8), source address (8).
    """

    metadata: list[PacketFieldMetadata] = dataclasses.field(
        default_factory=lambda: [
            PacketFieldMetadata(name="version", disp="ver.", length=1),
            PacketFieldMetadata(name="type_", disp="type", length=1),
            PacketFieldMetadata(name="network_id", disp="net", length=2),
            PacketFieldMetadata(name="destination", disp="dst", length=8),
            PacketFieldMetadata(name="source", disp="src", length=8),
        ]
    )
    version: int = MARI_PROTOCOL_VERSION
    type_: int = PacketType.DATA
    network_id: int = MARI_NET_ID_DEFAULT
    destination: int = MARI_BROADCAST_ADDRESS
    source: int = 0x0000000000000000

    def __repr__(self):
        # repr() must not raise: a malformed frame can carry a type byte that
        # is not a known PacketType, so fall back to the raw value.
        try:
            type_ = PacketType(self.type_).name
        except ValueError:
            type_ = f"0x{self.type_:02x}"
        return f"Header(version={self.version}, type_={type_}, network_id=0x{self.network_id:04x}, destination=0x{self.destination:016x}, source=0x{self.source:016x})"
51
+
52
+
53
@dataclass
class Frame:
    """Data class that holds a payload packet.

    Byte layout used by from_bytes/to_bytes: 20-byte Header, 1-byte
    HeaderStats, then the payload.
    """

    header: Header | None = None
    stats: HeaderStats = dataclasses.field(default_factory=HeaderStats)
    payload: bytes = b""

    def from_bytes(self, bytes_):
        """Parses ``bytes_`` into this frame and returns ``self``.

        The stats and payload are only replaced when the input is long enough
        to contain them.
        """
        self.header = Header().from_bytes(bytes_[0:20])
        if len(bytes_) > 20:
            self.stats = HeaderStats().from_bytes(bytes_[20:21])
        if len(bytes_) > 21:
            self.payload = bytes_[21:]
        return self

    def to_bytes(self, byteorder="little") -> bytes:
        """Serializes header, stats and payload (a header is required)."""
        header_bytes = self.header.to_bytes(byteorder)
        stats_bytes = self.stats.to_bytes(byteorder)
        return header_bytes + stats_bytes + self.payload

    def __repr__(self):
        # The default header is None; dataclasses.replace(None, ...) would raise.
        if self.header is None:
            header_repr = None
        else:
            header_repr = dataclasses.replace(self.header, metadata=[])
        return f"Frame(header={header_repr}, payload={self.payload})"
marilib/marilib.py ADDED
@@ -0,0 +1,35 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ from marilib.model import MariNode
4
+
5
+
6
+ class MarilibBase(ABC):
7
+ """Base class for Marilib applications."""
8
+
9
+ @abstractmethod
10
+ def update(self):
11
+ """Recurrent bookkeeping. Don't forget to call this periodically on your main loop."""
12
+
13
+ @abstractmethod
14
+ def nodes(self) -> list[MariNode]:
15
+ """Returns all nodes in the network."""
16
+
17
+ @abstractmethod
18
+ def add_node(self, address: int, gateway_address: int = None) -> MariNode | None:
19
+ """Adds a node to the network."""
20
+
21
+ @abstractmethod
22
+ def remove_node(self, address: int) -> MariNode | None:
23
+ """Removes a node from the network."""
24
+
25
+ @abstractmethod
26
+ def send_frame(self, dst: int, payload: bytes):
27
+ """Sends a frame to the network."""
28
+
29
+ @abstractmethod
30
+ def render_tui(self):
31
+ """Renders the TUI."""
32
+
33
+ @abstractmethod
34
+ def close_tui(self):
35
+ """Closes the TUI."""
@@ -0,0 +1,193 @@
1
+ import threading
2
+ from dataclasses import dataclass, field
3
+ from datetime import datetime
4
+ from typing import Any, Callable
5
+
6
+ from marilib.latency import LatencyTester
7
+ from marilib.mari_protocol import Frame, Header
8
+ from marilib.model import (
9
+ EdgeEvent,
10
+ GatewayInfo,
11
+ MariGateway,
12
+ MariNode,
13
+ NodeInfoCloud,
14
+ )
15
+ from marilib.communication_adapter import MQTTAdapter
16
+ from marilib.marilib import MarilibBase
17
+ from marilib.tui_cloud import MarilibTUICloud
18
+
19
# Single-byte payload: ASCII "L" (0x4C).
LOAD_PACKET_PAYLOAD = b"\x4c"
20
+
21
+
22
@dataclass
class MarilibCloud(MarilibBase):
    """
    The MarilibCloud class runs in a computer.
    It is used to communicate with a Mari radio gateway (nRF5340) via MQTT.

    Gateways are learned from GATEWAY_INFO events and stored in ``gateways``
    (keyed by gateway address); nodes are tracked per gateway. Incoming MQTT
    data is handled by ``on_mqtt_data_received`` (registered on the adapter in
    ``__post_init__``) — presumably on the MQTT client's own thread, TODO
    confirm in MQTTAdapter — while ``update()`` runs from the application
    loop, so changes to ``gateways`` are made while holding ``lock``.
    """

    cb_application: Callable[[EdgeEvent, MariNode | Frame | GatewayInfo], None]
    mqtt_interface: MQTTAdapter
    network_id: int
    tui: MarilibTUICloud | None = None

    logger: Any | None = None
    gateways: dict[int, MariGateway] = field(default_factory=dict)
    lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
    latency_tester: LatencyTester | None = None

    started_ts: datetime = field(default_factory=datetime.now)
    last_received_mqtt_data_ts: datetime = field(default_factory=datetime.now)
    main_file: str | None = None

    def __post_init__(self):
        """Configures and starts the MQTT adapter and logs the setup parameters."""
        self.setup_params = {
            "main_file": self.main_file or "unknown",
            "mqtt_host": self.mqtt_interface.host,
            "mqtt_port": self.mqtt_interface.port,
            "network_id": self.network_id_str,
        }
        self.mqtt_interface.set_network_id(self.network_id_str)
        self.mqtt_interface.set_on_data_received(self.on_mqtt_data_received)
        self.mqtt_interface.init()
        if self.logger:
            self.logger.log_setup_parameters(self.setup_params)

    # ============================ MarilibBase methods =========================

    def update(self):
        """Recurrent bookkeeping. Don't forget to call this periodically on your main loop."""
        with self.lock:
            # remove dead gateways
            self.gateways = {
                addr: gateway for addr, gateway in self.gateways.items() if gateway.is_alive
            }
            # update each gateway
            for gateway in self.gateways.values():
                gateway.update()
                if self.logger:
                    self.logger.log_periodic_metrics(gateway, gateway.nodes)

    @property
    def nodes(self) -> list[MariNode]:
        """Returns the nodes of every known gateway."""
        # Iterate a snapshot: gateways may be added concurrently by the MQTT
        # handler, and a dict that changes size during iteration raises.
        gateways = list(self.gateways.values())
        return [node for gateway in gateways for node in gateway.nodes]

    def add_node(self, address: int, gateway_address: int = None) -> MariNode | None:
        """Adds a node to the given gateway; returns None if the gateway is unknown."""
        with self.lock:
            gateway = self.gateways.get(gateway_address)
            if gateway:
                node = gateway.add_node(address)
                return node
            return None

    def remove_node(self, address: int, gateway_address: int = None) -> MariNode | None:
        """Removes a node from the given gateway; returns None if the gateway is unknown."""
        with self.lock:
            gateway = self.gateways.get(gateway_address)
            if gateway:
                node = gateway.remove_node(address)
                return node
            return None

    def send_frame(self, dst: int, payload: bytes):
        """
        Sends a frame to a gateway via MQTT.
        Consists in publishing a message to the /mari/{network_id}/to_edge topic.
        """
        # NOTE(review): Header uses its default network_id rather than
        # self.network_id — confirm the gateway overwrites this field.
        mari_frame = Frame(Header(destination=dst), payload=payload)

        self.mqtt_interface.send_data_to_edge(
            EdgeEvent.to_bytes(EdgeEvent.NODE_DATA) + mari_frame.to_bytes()
        )

    def render_tui(self):
        """Renders the TUI, if one is configured."""
        if self.tui:
            self.tui.render(self)

    def close_tui(self):
        """Closes the TUI, if one is configured."""
        if self.tui:
            self.tui.close()

    # ============================ MarilibCloud methods =========================

    @property
    def network_id_str(self) -> str:
        """The network id as 4 uppercase hex digits (e.g. ``0001``)."""
        return f"{self.network_id:04X}"

    # ============================ Callbacks ===================================

    def handle_mqtt_data(self, data: bytes) -> tuple[bool, EdgeEvent, Any]:
        """
        Handles the MQTT data received from the MQTT broker:
        - parses the event
        - updates node or gateway information
        - returns the event type and data (if any)

        Returns ``(False, EdgeEvent.UNKNOWN, None)`` for empty, unknown,
        unhandled or malformed messages.
        """

        if len(data) < 1:
            return False, EdgeEvent.UNKNOWN, None

        self.last_received_mqtt_data_ts = datetime.now()

        try:
            event_type = EdgeEvent(data[0])
        except ValueError:
            return False, EdgeEvent.UNKNOWN, None

        try:
            if event_type == EdgeEvent.NODE_JOINED:
                node_info = NodeInfoCloud().from_bytes(data[1:])
                if self.add_node(node_info.address, node_info.gateway_address):
                    return True, EdgeEvent.NODE_JOINED, node_info

            elif event_type == EdgeEvent.NODE_LEFT:
                node_info = NodeInfoCloud().from_bytes(data[1:])
                if self.remove_node(node_info.address, node_info.gateway_address):
                    return True, EdgeEvent.NODE_LEFT, node_info

            elif event_type == EdgeEvent.NODE_KEEP_ALIVE:
                node_info = NodeInfoCloud().from_bytes(data[1:])
                gateway = self.gateways.get(node_info.gateway_address)
                if gateway:
                    gateway.update_node_liveness(node_info.address)
                    return True, EdgeEvent.NODE_KEEP_ALIVE, node_info

            elif event_type == EdgeEvent.GATEWAY_INFO:
                gateway_info = GatewayInfo().from_bytes(data[1:])
                # Hold the lock: update() rebuilds self.gateways under it, so an
                # unlocked insert into the old dict could be silently lost.
                with self.lock:
                    gateway = self.gateways.get(gateway_info.address)
                    if not gateway:
                        # we are learning about a new gateway, so instantiate it and add it to the list
                        gateway = MariGateway(info=gateway_info)
                        self.gateways[gateway.info.address] = gateway
                    else:
                        gateway.set_info(gateway_info)
                return True, EdgeEvent.GATEWAY_INFO, gateway_info

            elif event_type == EdgeEvent.NODE_DATA:
                frame = Frame().from_bytes(data[1:])

                gateway_address = frame.header.destination
                node_address = frame.header.source
                gateway = self.gateways.get(gateway_address)
                # Check the gateway before dereferencing it: it may not have
                # been announced yet, or may have been pruned by update().
                if not gateway:
                    return False, EdgeEvent.UNKNOWN, None
                node = gateway.get_node(node_address)
                if not node:
                    return False, EdgeEvent.UNKNOWN, None

                gateway.update_node_liveness(node_address)
                gateway.register_received_frame(frame, is_test_packet=False)
                return True, EdgeEvent.NODE_DATA, frame

        except Exception as e:
            print(f"Error handling MQTT data: {e}")

        # fallback result in case of error
        return False, EdgeEvent.UNKNOWN, None

    def on_mqtt_data_received(self, data: bytes):
        """Adapter callback: handles the data, logs node join/leave events and notifies the application."""
        res, event_type, event_data = self.handle_mqtt_data(data)
        if res:
            if self.logger and event_type in [EdgeEvent.NODE_JOINED, EdgeEvent.NODE_LEFT]:
                # TODO: update the logging system to also support GATEWAY_INFO events from multiple gateways
                self.logger.log_event(
                    event_data.gateway_address, event_data.address, event_type.name
                )
            self.cb_application(event_type, event_data)