mcast-tools 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
"""mcast-tools — friendly multicast send/receive tools for lab environments."""

# Package version string.
# NOTE(review): presumably must match the distribution version in the
# packaging metadata (the wheel is 0.1.0) — confirm there is a single
# source of truth.
__version__ = "0.1.0"
mcast_tools/common.py ADDED
@@ -0,0 +1,513 @@
1
+ """Shared utilities for mcast-tools.
2
+
3
+ Provides:
4
+ - Custom packet header definition (struct format, encode/decode)
5
+ - Multicast address validation
6
+ - Route lookup via `ip -j route get`
7
+ - IGMP version sysctl management (read/force/restore)
8
+ - Human-readable formatting helpers
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import ipaddress
13
+ import json
14
+ import os
15
+ import re
16
+ import struct
17
+ import subprocess
18
+ import sys
19
+ from dataclasses import dataclass
20
+ from pathlib import Path
21
+ from typing import Optional
22
+
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Packet format
26
+ # ---------------------------------------------------------------------------
27
+ #
28
+ # All fields network byte order (big-endian).
29
+ #
30
+ # uint32 magic 0x4D434153 ("MCAS")
31
+ # uint32 version protocol version (currently 1)
32
+ # uint64 seq sequence number, starts at 0
33
+ # uint64 send_ns sender monotonic timestamp (nanoseconds)
34
+ # uint32 sender_id random ID per sender invocation (detect restart)
35
+ # uint16 payload_len bytes of payload following header
36
+ # uint8 orig_ttl TTL the sender set on outgoing socket
37
+ # uint8 orig_dscp DSCP the sender set on outgoing socket
38
+ #
39
+ # Total: 32 bytes. Followed by `payload_len` bytes of padding.
40
+
41
PACKET_MAGIC = 0x4D434153  # "MCAS"; PacketHeader.decode() rejects any other magic
PROTOCOL_VERSION = 1  # header layout version; PacketHeader.decode() rejects other values
# "!" selects network (big-endian) byte order with standard sizes and no
# alignment padding, so the packed size is exactly the sum of the field
# widths in the layout comment above.
HEADER_STRUCT = struct.Struct("!IIQQIHBB")
HEADER_SIZE = HEADER_STRUCT.size  # 32

# Import-time sanity check: fail loudly if the format string stops matching
# the documented 32-byte layout. (assert statements are skipped under -O.)
assert HEADER_SIZE == 32, f"Header size drift: {HEADER_SIZE}"
47
+
48
+
49
@dataclass
class PacketHeader:
    """One decoded mcast-tools header; fields follow HEADER_STRUCT order."""

    magic: int
    version: int
    seq: int
    send_ns: int
    sender_id: int
    payload_len: int
    orig_ttl: int
    orig_dscp: int

    @classmethod
    def decode(cls, buf: bytes) -> Optional["PacketHeader"]:
        """Parse the leading HEADER_SIZE bytes of `buf` into a header.

        Returns None when `buf` is shorter than HEADER_SIZE, or when its
        magic number or protocol version differs from PACKET_MAGIC /
        PROTOCOL_VERSION. Bytes after the header are ignored.
        """
        if len(buf) >= HEADER_SIZE:
            fields = HEADER_STRUCT.unpack_from(buf, 0)
            # The dataclass field order matches the struct field order, so the
            # unpacked tuple can be passed positionally.
            if fields[0] == PACKET_MAGIC and fields[1] == PROTOCOL_VERSION:
                return cls(*fields)
        return None

    def encode(self) -> bytes:
        """Pack this header into its fixed-size wire form via HEADER_STRUCT."""
        fields = (
            self.magic,
            self.version,
            self.seq,
            self.send_ns,
            self.sender_id,
            self.payload_len,
            self.orig_ttl,
            self.orig_dscp,
        )
        return HEADER_STRUCT.pack(*fields)
96
+
97
+
98
def make_packet(
    seq: int,
    send_ns: int,
    sender_id: int,
    orig_ttl: int,
    orig_dscp: int,
    datagram_size: int,
) -> bytes:
    """Build a full datagram of `datagram_size` bytes (header + padding).

    The datagram is the packed header (HEADER_STRUCT carrying PACKET_MAGIC
    and PROTOCOL_VERSION) followed by ``datagram_size - HEADER_SIZE`` bytes
    of 0xA5 padding.

    Args:
        seq, send_ns, sender_id, orig_ttl, orig_dscp: header field values,
            packed as-is.
        datagram_size: total datagram size in bytes, header included.

    Returns:
        The encoded datagram, exactly `datagram_size` bytes long.

    Raises:
        ValueError: if `datagram_size` is smaller than HEADER_SIZE, or so
            large that the padding length does not fit the header's uint16
            ``payload_len`` field. Note that over IPv4 one UDP datagram
            carries at most 65507 bytes of payload, so sizes above that are
            likely to be rejected when sent.
    """
    if datagram_size < HEADER_SIZE:
        raise ValueError(
            f"datagram_size {datagram_size} too small (min {HEADER_SIZE})"
        )
    payload_len = datagram_size - HEADER_SIZE
    # payload_len is packed as "H" (uint16); reject overflow here with the
    # same friendly ValueError instead of letting struct.error escape.
    max_payload_len = 0xFFFF
    if payload_len > max_payload_len:
        raise ValueError(
            f"datagram_size {datagram_size} too large "
            f"(max {HEADER_SIZE + max_payload_len})"
        )
    header = HEADER_STRUCT.pack(
        PACKET_MAGIC,
        PROTOCOL_VERSION,
        seq,
        send_ns,
        sender_id,
        payload_len,
        orig_ttl,
        orig_dscp,
    )
    # Deterministic padding pattern — easier to eyeball in a packet capture.
    padding = b"\xA5" * payload_len
    return header + padding
125
+
126
+
127
+ # ---------------------------------------------------------------------------
128
+ # Multicast address validation
129
+ # ---------------------------------------------------------------------------
130
+
131
def validate_multicast_group(group: str) -> ipaddress.IPv4Address:
    """Validate `group` is an IPv4 multicast address. Returns the address object.

    Args:
        group: textual address, e.g. ``"239.1.1.1"``.

    Returns:
        The parsed ``ipaddress.IPv4Address``.

    Raises:
        ValueError: with a friendly message if `group` is not an IP address,
            is an IPv6 address, or lies outside 224.0.0.0/4.
    """
    try:
        addr = ipaddress.ip_address(group)
    except ValueError:
        # `from None`: the message below is complete; chaining the internal
        # ipaddress error would make it look like a crash inside the handler.
        raise ValueError(f"'{group}' is not a valid IP address") from None
    if not isinstance(addr, ipaddress.IPv4Address):
        raise ValueError(
            f"'{group}' is IPv6 — mcast-tools v0.1 supports IPv4 only"
        )
    if not addr.is_multicast:
        raise ValueError(
            f"'{group}' is not a multicast address "
            f"(IPv4 multicast range is 224.0.0.0/4)"
        )
    return addr
150
+
151
+
152
def classify_multicast_group(addr: ipaddress.IPv4Address) -> str:
    """Describe the scope of multicast group `addr` in a few words.

    Ranges are tested in order (224.0.0.0/24, 232.0.0.0/8, 233.0.0.0/8,
    239.0.0.0/8); the first match wins, and anything else is reported as
    "any-source multicast".
    """
    scope_table = (
        ("224.0.0.0/24", "link-local (not routed)"),
        ("232.0.0.0/8", "SSM (source-specific)"),
        ("233.0.0.0/8", "GLOP"),
        ("239.0.0.0/8", "admin-scoped"),
    )
    for cidr, label in scope_table:
        if addr in ipaddress.IPv4Network(cidr):
            return label
    return "any-source multicast"
163
+
164
+
165
def is_ssm_group(addr: ipaddress.IPv4Address) -> bool:
    """Return True if `addr` falls in the SSM block 232.0.0.0/8."""
    ssm_block = ipaddress.IPv4Network("232.0.0.0/8")
    return addr in ssm_block
167
+
168
+
169
def validate_source(source: str) -> ipaddress.IPv4Address:
    """Validate that `source` is a non-multicast IPv4 address and return it.

    Args:
        source: textual address, e.g. ``"192.0.2.10"``.

    Returns:
        The parsed ``ipaddress.IPv4Address``.

    Raises:
        ValueError: with a friendly message if `source` is not an IP address,
            is an IPv6 address, or is a multicast address.
    """
    try:
        addr = ipaddress.ip_address(source)
    except ValueError:
        # `from None`: keep the user-facing error free of the chained
        # ipaddress traceback.
        raise ValueError(f"'{source}' is not a valid IP address") from None
    if not isinstance(addr, ipaddress.IPv4Address):
        raise ValueError("Source must be an IPv4 address")
    if addr.is_multicast:
        raise ValueError(
            f"Source '{source}' is a multicast address — sources must be unicast"
        )
    return addr
181
+
182
+
183
+ # ---------------------------------------------------------------------------
184
+ # Route lookup — figure out which interface a multicast group will egress on
185
+ # ---------------------------------------------------------------------------
186
+
187
@dataclass
class RouteInfo:
    """Egress interface (and preferred source) the kernel reports for a destination."""

    interface: str  # the route's "dev" field
    src_addr: Optional[str]  # preferred source IP ("prefsrc"), if reported

    @classmethod
    def for_destination(cls, dest: str) -> Optional["RouteInfo"]:
        """Run `ip -j route get <dest>` and parse the result.

        Reads the ``dev`` and ``prefsrc`` keys of the first JSON entry.

        Returns:
            A RouteInfo, or None on any failure: ``ip`` missing or not
            executable, the 3-second timeout expiring, a non-zero exit
            status, output that is not a non-empty JSON list of objects, or
            an entry without ``dev``.
        """
        try:
            result = subprocess.run(
                ["ip", "-j", "route", "get", dest],
                capture_output=True,
                text=True,
                timeout=3,
            )
        except (OSError, subprocess.TimeoutExpired):
            # OSError covers FileNotFoundError (no iproute2) as well as
            # PermissionError and other exec failures.
            return None
        if result.returncode != 0:
            return None
        try:
            data = json.loads(result.stdout)
        except json.JSONDecodeError:
            return None
        # Anything other than a non-empty list of route objects is unusable.
        if not isinstance(data, list) or not data:
            return None
        entry = data[0]
        if not isinstance(entry, dict):
            return None
        iface = entry.get("dev")
        if not iface:
            return None
        return cls(interface=iface, src_addr=entry.get("prefsrc"))
217
+
218
+
219
+ # ---------------------------------------------------------------------------
220
+ # IGMP version control
221
+ # ---------------------------------------------------------------------------
222
+
223
def igmp_force_path(interface: str) -> Path:
    """Return the force_igmp_version sysctl file for `interface` under /proc/sys."""
    sysctl_file = "/".join(
        ("/proc/sys/net/ipv4/conf", interface, "force_igmp_version")
    )
    return Path(sysctl_file)
225
+
226
+
227
def read_force_igmp_version(interface: str) -> Optional[int]:
    """Return the current force_igmp_version sysctl value for `interface`.

    Returns None if the sysctl file does not exist, cannot be read for lack
    of permission, or does not contain an integer.
    """
    sysctl_file = igmp_force_path(interface)
    try:
        text = sysctl_file.read_text()
        value = int(text.strip())
    except (FileNotFoundError, ValueError, PermissionError):
        value = None
    return value
234
+
235
+
236
def write_force_igmp_version(interface: str, version: int) -> bool:
    """Force IGMP version on an interface. Requires root. Returns True on success.

    Writes `version` to the interface's force_igmp_version sysctl file
    (see igmp_force_path).

    Args:
        interface: interface name, e.g. ``"eth0"``.
        version: 1, 2 or 3 to force that version; 0 means "auto"
            (kernel default — usually v3).

    Returns:
        True if the value was written. False if the write failed for any
        OS-level reason, such as:
        - the interface does not exist;
        - not running as root;
        - a read-only /proc/sys (e.g. inside a container);
        - the kernel rejected the value.

    Raises:
        ValueError: if `version` is not one of 0, 1, 2, 3.
    """
    if version not in (0, 1, 2, 3):
        raise ValueError(f"Invalid IGMP version: {version}")
    p = igmp_force_path(interface)
    try:
        p.write_text(f"{version}\n")
    except OSError:
        # FileNotFoundError and PermissionError are OSError subclasses; so is
        # EROFS from a read-only /proc/sys, which previously escaped.
        return False
    return True
249
+
250
+
251
def _igmp_group_to_dotted(hex_addr: str) -> str:
    """Convert a /proc/net/igmp group field (host-order hex u32) to dotted quad.

    The group is stored in network byte order but printed as a 32-bit hex
    integer in host byte order (little-endian on x86_64). The host's
    ``sys.byteorder`` is therefore used to recover the address bytes, rather
    than assuming a little-endian host. Fields that are not exactly 4
    hex-encoded bytes are returned unchanged.
    """
    try:
        raw_bytes = bytes.fromhex(hex_addr)
    except ValueError:
        return hex_addr
    if len(raw_bytes) != 4:
        return hex_addr
    # bytes.fromhex yields the printed integer big-endian; re-emit it in host
    # order to get back the original in-memory (network-order) bytes.
    network_order = int.from_bytes(raw_bytes, "big").to_bytes(4, sys.byteorder)
    return ".".join(str(b) for b in network_order)


def parse_proc_net_igmp(path: str | Path = "/proc/net/igmp") -> dict[str, dict]:
    """Parse /proc/net/igmp into a structured dict keyed by interface name.

    The file looks like::

        Idx\tDevice : Count Querier\tGroup Users Timer\tReporter
        1\tlo : 1 V3
        \t\t\t\t010000E0 1 0:00000000\t\t0
        2\teth0 : 1 V3
        \t\t\t\t010000E0 1 0:00000000\t\t0

    Despite the header column naming, per-interface rows really carry only
    "count" and the operational IGMP version. Group rows are indented and
    contain hex address (host byte order), users, timer, and reporter.

    Args:
        path: file to parse. Defaults to the live kernel table. Other paths
            are mainly for tests or copies captured on a host with the same
            byte order, since group addresses are decoded using this host's
            ``sys.byteorder``.

    Returns:
        {
            "eth0": {
                "version": "V3",
                "groups": [
                    {"address": "224.0.0.1", "users": 1, "timer": "0:00000000",
                     "reporter": "0"},
                    ...
                ]
            },
            ...
        }
        An empty dict if the file is missing or unreadable.
    """
    try:
        raw = Path(path).read_text()
    except (FileNotFoundError, PermissionError):
        return {}

    interfaces: dict[str, dict] = {}
    current_iface: Optional[str] = None

    for line in raw.splitlines():
        # Skip blank lines and the "Idx ..." column-header line.
        if not line.strip() or line.startswith("Idx"):
            continue

        if not line.startswith(("\t", " ")):
            # Interface row: "<idx>\t<device>: <count> <version>".
            # Reset first so a malformed row never lets the following group
            # rows attach to the previous interface.
            current_iface = None
            if ":" not in line:
                continue
            left, right = line.split(":", 1)
            left_parts = left.split()  # [idx, device]
            right_parts = right.split()  # [count, version]
            if len(left_parts) >= 2 and len(right_parts) >= 2:
                current_iface = left_parts[1]
                interfaces[current_iface] = {
                    "version": right_parts[-1],
                    "groups": [],
                }
            continue

        # Group row (indented): "<hex_addr> <users> <timer> <reporter>".
        if current_iface is None:
            continue
        parts = line.split()
        if len(parts) < 4:
            continue
        hex_addr, users_field, timer, reporter = parts[:4]
        try:
            users = int(users_field)
        except ValueError:
            continue
        interfaces[current_iface]["groups"].append(
            {
                "address": _igmp_group_to_dotted(hex_addr),
                "users": users,
                "timer": timer,
                "reporter": reporter,
            }
        )
    return interfaces
343
+
344
+
345
def get_interface_igmp_version(interface: str) -> Optional[str]:
    """Return the operational IGMP version (e.g. "V3") of `interface`.

    The value comes from parse_proc_net_igmp(). None is returned when the
    interface has no entry there.
    """
    per_iface = parse_proc_net_igmp().get(interface) or {}
    return per_iface.get("version")
352
+
353
+
354
+ # ---------------------------------------------------------------------------
355
+ # Rate parsing
356
+ # ---------------------------------------------------------------------------
357
+
358
_RATE_RE = re.compile(
    r"^\s*(?P<num>\d+(?:\.\d+)?)\s*(?P<unit>pps|bps|kbps|mbps|gbps|kpps|mpps)?\s*$",
    re.IGNORECASE,
)


def parse_rate(rate_str: str, datagram_size: int) -> float:
    """Parse a rate string like '25pps' or '1Mbps' to packets-per-second (float).

    A bare number is taken as pps. Units are case-insensitive:
    - packet rates: pps, kpps, mpps;
    - bit rates: bps, kbps, mbps, gbps.

    `datagram_size` is used to convert bps -> pps. We use the UDP payload size,
    NOT including IP+UDP header overhead (28 bytes), since that's the user's
    mental model — they're configuring application-layer bytes. It is not
    consulted for packet-rate units.

    Raises:
        ValueError: if the string does not parse, if a bit rate is given with
            a non-positive `datagram_size`, or if a bit rate works out to
            less than 1 pps.
    """
    m = _RATE_RE.match(rate_str)
    if not m:
        raise ValueError(
            f"Could not parse rate '{rate_str}'. "
            f"Examples: '25pps', '500pps', '1Mbps', '100kbps'"
        )
    num = float(m.group("num"))
    unit = (m.group("unit") or "pps").lower()

    pps_multipliers = {"pps": 1, "kpps": 1000, "mpps": 1_000_000}
    bps_multipliers = {
        "bps": 1,
        "kbps": 1000,
        "mbps": 1_000_000,
        "gbps": 1_000_000_000,
    }

    if unit in pps_multipliers:
        return num * pps_multipliers[unit]
    if unit not in bps_multipliers:
        # Unreachable while _RATE_RE restricts the unit set; kept defensively.
        raise ValueError(f"Unknown rate unit: {unit}")
    if datagram_size <= 0:
        # Without this check a zero size raised ZeroDivisionError below.
        raise ValueError(
            f"datagram_size must be positive to convert '{rate_str}' to pps "
            f"(got {datagram_size})"
        )
    bps = num * bps_multipliers[unit]
    pps = bps / (datagram_size * 8)
    if pps < 1:
        raise ValueError(
            f"Rate '{rate_str}' yields {pps:.3f} pps at {datagram_size}-byte "
            f"datagrams — too slow to be meaningful (min 1 pps)"
        )
    return pps
404
+
405
+
406
+ # ---------------------------------------------------------------------------
407
+ # DSCP / TOS helpers
408
+ # ---------------------------------------------------------------------------
409
+
410
DSCP_NAMES = {
    "be": 0, "cs0": 0,
    "cs1": 8, "af11": 10, "af12": 12, "af13": 14,
    "cs2": 16, "af21": 18, "af22": 20, "af23": 22,
    "cs3": 24, "af31": 26, "af32": 28, "af33": 30,
    "cs4": 32, "af41": 34, "af42": 36, "af43": 38,
    "cs5": 40, "ef": 46,
    "cs6": 48, "cs7": 56,
}


def _first_name_per_value(names: dict[str, int]) -> dict[int, str]:
    """Invert `names` to value -> upper-case name, keeping the FIRST listed alias."""
    lookup: dict[int, str] = {}
    for name, value in names.items():
        lookup.setdefault(value, name.upper())
    return lookup


# Reverse map for display. Value 0 has two aliases ("be", "cs0"). A plain
# {v: k} comprehension lets the later alias win (0 -> "CS0"). Keeping the
# first one makes 0 render as "BE", and key order still follows DSCP_NAMES.
DSCP_NAME_LOOKUP = _first_name_per_value(DSCP_NAMES)
421
+
422
+
423
def parse_dscp(s: str) -> int:
    """Parse a DSCP value — accepts numeric (0-63) or symbolic name (ef, af41, ...).

    Names are looked up in DSCP_NAMES case-insensitively. Numbers accept any
    literal ``int(x, 0)`` understands (``"46"``, ``"0x2e"``, ``"0b101110"``).
    That means decimal values with a leading zero such as ``"046"`` are
    rejected. Surrounding whitespace is ignored.

    Raises:
        ValueError: for an unrecognised name/number or a value outside 0-63.
    """
    s_low = s.strip().lower()
    if s_low in DSCP_NAMES:
        return DSCP_NAMES[s_low]
    try:
        v = int(s_low, 0)
    except ValueError:
        # `from None`: the friendly message replaces int()'s error rather
        # than being chained beneath it.
        raise ValueError(
            f"Unknown DSCP value '{s}'. Use a number 0-63 or a name "
            f"(be, cs1-cs7, af11-af43, ef)."
        ) from None
    if not 0 <= v <= 63:
        raise ValueError(f"DSCP {v} out of range (0-63)")
    return v
438
+
439
+
440
def dscp_label(dscp: int) -> str:
    """Render a DSCP value for display.

    Values found in DSCP_NAME_LOOKUP render as ``"NAME (value)"``, e.g.
    ``"EF (46)"``. Any other value renders as the bare number, e.g. ``"5"``.
    """
    name = DSCP_NAME_LOOKUP.get(dscp)
    return f"{name} ({dscp})" if name else f"{dscp}"
446
+
447
+
448
def dscp_from_tos_byte(tos: int) -> int:
    """Return the 6-bit DSCP held in the top bits of TOS byte `tos` (ECN bits dropped)."""
    return (tos >> 2) & 0x3F
451
+
452
+
453
def tos_byte_from_dscp(dscp: int) -> int:
    """Build a TOS byte carrying `dscp` (masked to 6 bits) in its top bits, ECN zero."""
    return (dscp << 2) & 0xFC
455
+
456
+
457
+ # ---------------------------------------------------------------------------
458
+ # Formatting helpers
459
+ # ---------------------------------------------------------------------------
460
+
461
def fmt_bytes(n: int) -> str:
    """Render a byte count with decimal SI prefixes (B, KB, MB, GB).

    Counts below 1000 are shown unscaled; larger ones with two decimals.
    """
    if n < 1000:
        return f"{n} B"
    for unit, scale in (("KB", 1000), ("MB", 1_000_000)):
        if n < scale * 1000:
            return f"{n / scale:.2f} {unit}"
    return f"{n / 1_000_000_000:.2f} GB"
470
+
471
+
472
def fmt_bps(bps: float) -> str:
    """Render a bit rate with decimal SI prefixes (bps, kbps, Mbps, Gbps).

    Rates below 1000 are shown as whole numbers; larger ones with two decimals.
    """
    if bps < 1000:
        return f"{bps:.0f} bps"
    for unit, scale in (("kbps", 1000), ("Mbps", 1_000_000)):
        if bps < scale * 1000:
            return f"{bps / scale:.2f} {unit}"
    return f"{bps / 1_000_000_000:.2f} Gbps"
480
+
481
+
482
def fmt_duration(seconds: float) -> str:
    """Render a duration as zero-padded HH:MM:SS.

    Negative durations clamp to 00:00:00 and fractional seconds are
    truncated. Hours are not wrapped, so 100 hours renders as "100:00:00".
    """
    whole = int(max(seconds, 0))
    hours, remainder = divmod(whole, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
491
+
492
+
493
def fmt_pps(pps: float) -> str:
    """Render a packet rate with two decimals in pps, kpps or Mpps."""
    scale, unit = (
        (1, "pps") if pps < 1000
        else (1000, "kpps") if pps < 1_000_000
        else (1_000_000, "Mpps")
    )
    return f"{pps / scale:.2f} {unit}"
499
+
500
+
501
+ # ---------------------------------------------------------------------------
502
+ # Privilege checks
503
+ # ---------------------------------------------------------------------------
504
+
505
def require_root_or_warn(action: str) -> bool:
    """Report whether the effective UID is 0 (root).

    When not root, a warning naming `action` is printed to stderr before
    returning False.
    """
    is_root = os.geteuid() == 0
    if not is_root:
        print(
            f"WARNING: not running as root — {action} may fail or be ignored.",
            file=sys.stderr,
        )
    return is_root