hysn-firecracker-python 1.0.3.post0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
firecracker/network.py ADDED
@@ -0,0 +1,1230 @@
1
+ import os
2
+ import sys
3
+ import ipaddress
4
+ from pyroute2 import IPRoute
5
+ from firecracker.logger import Logger
6
+ from firecracker.utils import run
7
+ from firecracker.config import MicroVMConfig
8
+ from firecracker.exceptions import NetworkError, ConfigurationError
9
+ from ipaddress import IPv4Address, IPv4Network, AddressValueError
10
+
11
# Distribution packages install the libnftables Python bindings into the
# system interpreter's package directory, which may not be on sys.path (for
# example inside a virtualenv). Append the first such directory that exists;
# appending rather than prepending keeps already-visible packages first.
_SYSTEM_PACKAGE_DIRS = (
    "/usr/lib/python3.12/site-packages",
    "/usr/lib/python3/dist-packages",
)
_system_package_dir = next(
    (candidate for candidate in _SYSTEM_PACKAGE_DIRS if os.path.exists(candidate)),
    None,
)
if _system_package_dir is not None:
    sys.path.append(_system_package_dir)

# The nftables bindings are optional. NFTABLES_AVAILABLE records whether they
# imported so NetworkManager can decide whether to create an Nftables handle.
try:
    from nftables import Nftables
except ImportError:
    NFTABLES_AVAILABLE = False
else:
    NFTABLES_AVAILABLE = True
22
+
23
+
24
class NetworkManager:
    """Manages network-related operations for Firecracker VMs.

    Two kernel interfaces are used:

    * ``pyroute2.IPRoute`` (``self._ipr``) for TAP devices and addresses.
    * The libnftables Python bindings (``self._nft``) for forwarding, NAT and
      port-forwarding rules. ``self._nft`` is ``None`` when the bindings could
      not be imported.

    Rules created by this class:

    * table ``ip filter``, chain ``FORWARD``: per-TAP accept rules
      (``add_nat_rules``).
    * table ``ip nat``, chain ``POSTROUTING``: one masquerade rule commented
      ``"microVM outbound NAT"`` (``create_masquerade``) plus per-machine
      masquerade rules commented ``"machine_id=<id>"`` (``add_port_forward``).
    * table ``nat`` (``ip`` or ``ip6`` family), chain ``PREROUTING``: DNAT
      rules commented ``"machine_id=<id> host_port=<port> vm_port=<port>"``
      (``add_port_forward``).
    """

    # Comment tagging the single masquerade rule that create_masquerade reuses
    # for every microVM.
    _MASQUERADE_COMMENT = "microVM outbound NAT"

    def __init__(self, verbose: bool = False, level: str = "INFO"):
        """Create the manager and open its netlink / nftables handles.

        Args:
            verbose (bool): Enable verbose debug/info logging.
            level (str): Log level passed to the project ``Logger``.
        """
        self._config = MicroVMConfig()
        self._config.verbose = verbose

        if NFTABLES_AVAILABLE:
            self._nft = Nftables()
            # JSON output lets the list commands below be parsed as dicts.
            self._nft.set_json_output(True)
        else:
            self._nft = None

        self._ipr = IPRoute()
        self._logger = Logger(level=level, verbose=verbose)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _base_chain(
        family: str, table: str, name: str, chain_type: str, hook: str, prio: int
    ) -> dict:
        """Build an nftables JSON ``add chain`` command for a base chain.

        libnftables-json(5) names the priority property ``prio``. ``type``,
        ``hook`` and ``prio`` must all be present for the chain to be attached
        to a netfilter hook.
        """
        return {
            "add": {
                "chain": {
                    "family": family,
                    "table": table,
                    "name": name,
                    "type": chain_type,
                    "hook": hook,
                    "prio": prio,
                    "policy": "accept",
                }
            }
        }

    @staticmethod
    def _port_forward_comment(id: str, host_port: int, dest_port: int) -> str:
        """Return the comment that tags a machine's PREROUTING DNAT rule."""
        return f"machine_id={id} host_port={host_port} vm_port={dest_port}"

    @staticmethod
    def _comment_belongs_to(comment: str, id: str) -> bool:
        """Tell whether a rule comment was written for machine ``id``.

        Comments written by this class are either ``"machine_id=<id>"`` or
        ``"machine_id=<id> host_port=..."``. Comparing the whole token keeps
        ``id="vm1"`` from matching ``"machine_id=vm10"``.
        """
        tag = f"machine_id={id}"
        return comment == tag or comment.startswith(tag + " ")

    @staticmethod
    def _payload_match(expr: dict, field: str):
        """Return ``expr["match"]`` if it is an ``==`` payload match on ``field``.

        Returns:
            dict | None: The match body, or ``None`` for any other expression.
        """
        match = expr.get("match")
        if not isinstance(match, dict) or match.get("op") != "==":
            return None
        left = match.get("left")
        payload = left.get("payload") if isinstance(left, dict) else None
        if not isinstance(payload, dict) or payload.get("field") != field:
            return None
        return match

    @classmethod
    def _has_payload_value(cls, exprs: list, field: str, value) -> bool:
        """Tell whether any expression is an ``==`` payload match of ``field`` to ``value``."""
        for expr in exprs:
            match = cls._payload_match(expr, field)
            if match is not None and match.get("right") == value:
                return True
        return False

    def _list_table_rules(self, table: str = "nat") -> list:
        """Return the rule objects of the ``ip`` family table ``table``.

        Args:
            table (str): Name of the ``ip`` family table to list.

        Returns:
            list: The ``rule`` dicts from the JSON listing (tables and chains
            are dropped).

        Raises:
            NetworkError: If nftables returns a non-zero status.
        """
        list_cmd = {"nftables": [{"list": {"table": {"family": "ip", "name": table}}}]}
        rc, output, error = self._nft.json_cmd(list_cmd)
        if rc != 0:
            raise NetworkError(f"Failed to list nftables table {table}: {error}")
        items = output.get("nftables", []) if isinstance(output, dict) else []
        return [item["rule"] for item in items if "rule" in item]

    def _delete_nat_rule(self, chain: str, handle) -> bool:
        """Delete one rule from the ``ip nat`` table.

        Args:
            chain (str): Exact chain name (nftables chain names are case-sensitive).
            handle: Rule handle reported by the listing.

        Returns:
            bool: True if nftables reported success.
        """
        rc, _, error = self._nft.cmd(f"delete rule nat {chain} handle {handle}")
        if self._config.verbose:
            if rc == 0:
                self._logger.debug(f"{chain} rule with handle {handle} deleted")
            else:
                self._logger.warn(
                    f"Error deleting {chain} rule with handle {handle}: {error}"
                )
        return rc == 0

    # ------------------------------------------------------------------
    # Interface helpers
    # ------------------------------------------------------------------

    def get_interface_name(self) -> str:
        """Get the name of the host interface that carries the default route.

        Returns:
            str: Name of the network interface. If several default routes
            exist, the first one printed is used.

        Raises:
            RuntimeError: If unable to determine the interface name.
        """
        process = run("ip route | grep default | awk '{print $5}'")
        # With several default routes the pipeline prints one name per line.
        # With none, it may still report success while printing nothing.
        names = process.stdout.split() if process.returncode == 0 else []
        if not names:
            raise RuntimeError("Unable to determine the interface name")

        iface_name = names[0]
        if self._config.verbose:
            self._logger.debug(f"Default interface name: {iface_name}")
        return iface_name

    def get_gateway_ip(self, ip: str) -> str:
        """Derive the gateway IP by setting the host part of ``ip`` to 1.

        For IPv4 the last octet becomes 1 (``/24`` assumption). For IPv6 the
        last 16-bit segment becomes 1.

        Args:
            ip (str): IP address to derive the gateway IP from.

        Returns:
            str: Derived gateway IP.

        Raises:
            NetworkError: If the IP address is invalid.
        """
        try:
            ip_obj = ipaddress.ip_address(ip)
        except ValueError as e:
            # ip_address() raises plain ValueError; AddressValueError is only
            # raised by the IPv4Address/IPv6Address constructors.
            raise NetworkError(f"Invalid IP address format: {ip}") from e

        if isinstance(ip_obj, IPv4Address):
            gateway_ip = IPv4Address((int(ip_obj) & 0xFFFFFF00) | 1)
        else:
            segments = ip_obj.exploded.split(":")
            segments[-1] = "1"
            gateway_ip = ipaddress.IPv6Address(":".join(segments))

        if self._config.verbose:
            self._logger.debug(f"Derived gateway IP: {gateway_ip}")
        return str(gateway_ip)

    def setup(self, tap_name: str, iface_name: str, gateway_ip: str):
        """Set up the network for the Firecracker VM.

        Creates the TAP device if missing, then adds the forwarding rules and
        ensures the shared masquerade rule exists.

        Args:
            tap_name (str): Name of the TAP device.
            iface_name (str): Host interface used for outbound traffic.
            gateway_ip (str): Address assigned to the TAP device on creation.
        """
        if not self.check_tap_device(tap_name):
            self.create_tap(tap_name, iface_name, gateway_ip)

        self.add_nat_rules(tap_name, iface_name)
        self.create_masquerade(iface_name)

    def find_tap_interface_rules(self, rules, tap_name):
        """Find rules that match exactly on the specified tap interface name.

        Args:
            rules (list): nftables JSON listing items to search through.
            tap_name (str): Name of the tap device to find.

        Returns:
            list: One ``{"handle", "chain", "interface"}`` dict per matching
            rule. Matching is by exact name, so ``tap1`` does not match ``tap10``.
        """
        tap_rules = []
        for item in rules:
            rule = item.get("rule")
            if not rule:
                continue
            for expr in rule.get("expr", []):
                match = expr.get("match")
                if not isinstance(match, dict) or match.get("right") != tap_name:
                    continue
                if self._config.verbose:
                    self._logger.debug(
                        f"Found matching rule for {tap_name} with handle {rule['handle']}"
                    )
                tap_rules.append(
                    {
                        "handle": rule["handle"],
                        "chain": rule["chain"],
                        "interface": match["right"],
                    }
                )
                # One entry per rule: deleting the same handle twice would fail.
                break
        return tap_rules

    def check_tap_device(self, tap_device_name: str) -> bool:
        """Check if the tap device exists in the system using pyroute2.

        Args:
            tap_device_name (str): Name of the tap device to check.

        Returns:
            bool: True if the device exists, False otherwise.

        Raises:
            NetworkError: If checking the tap device fails.
        """
        try:
            return bool(self._ipr.link_lookup(ifname=tap_device_name))
        except Exception as e:
            raise NetworkError(
                f"Failed to check tap device {tap_device_name}: {str(e)}"
            ) from e

    def is_nftables_available(self) -> bool:
        """Check if nftables functionality is available.

        Always returns False while running under pytest (``PYTEST_CURRENT_TEST``
        is set), so tests never touch the host firewall.

        Returns:
            bool: True if nftables is available, False otherwise.
        """
        if "PYTEST_CURRENT_TEST" in os.environ:
            return False
        return NFTABLES_AVAILABLE and self._nft is not None

    def _safe_nft_cmd(self, cmd, json_cmd=True):
        """Safely execute an nftables command.

        Args:
            cmd: Command to execute (a dict for JSON, a string otherwise).
            json_cmd (bool): Whether to use ``json_cmd`` or ``cmd``.

        Returns:
            tuple: ``(return_code, output, error)``. ``(None, None, None)`` if
            nftables is not available; ``(1, None, message)`` if the call raised.
        """
        if not self.is_nftables_available():
            if self._config.verbose:
                self._logger.warn("Nftables not available, skipping command")
            return None, None, None

        try:
            if json_cmd:
                return self._nft.json_cmd(cmd)
            return self._nft.cmd(cmd)
        except Exception as e:
            if self._config.verbose:
                self._logger.error(f"Nftables command failed: {str(e)}")
            return 1, None, str(e)

    # ------------------------------------------------------------------
    # Forwarding / NAT rules
    # ------------------------------------------------------------------

    def add_nat_rules(self, tap_name: str, iface_name: str):
        """Add FORWARD accept rules between a tap device and the host interface.

        Ensures the ``ip nat`` table with a ``POSTROUTING`` base chain and the
        ``ip filter`` table with a ``FORWARD`` base chain exist. It then adds
        two rules: tap -> iface (all traffic) and iface -> tap (established /
        related only). The whole batch is sent in one nftables call.

        Args:
            tap_name (str): Name of the tap device.
            iface_name (str): Name of the host interface.

        Raises:
            NetworkError: If adding the rules fails.
        """
        if not self.is_nftables_available():
            if self._config.verbose:
                self._logger.warn("Nftables not available, skipping NAT rules")
            return

        batch = {
            "nftables": [
                {"add": {"table": {"family": "ip", "name": "nat"}}},
                self._base_chain("ip", "nat", "POSTROUTING", "nat", "postrouting", 100),
                {"add": {"table": {"family": "ip", "name": "filter"}}},
                self._base_chain("ip", "filter", "FORWARD", "filter", "forward", 0),
                {
                    "add": {
                        "rule": {
                            "family": "ip",
                            "table": "filter",
                            "chain": "FORWARD",
                            "expr": [
                                {
                                    "match": {
                                        "left": {"meta": {"key": "iifname"}},
                                        "op": "==",
                                        "right": tap_name,
                                    }
                                },
                                {
                                    "match": {
                                        "left": {"meta": {"key": "oifname"}},
                                        "op": "==",
                                        "right": iface_name,
                                    }
                                },
                                {"counter": {"packets": 0, "bytes": 0}},
                                {"accept": None},
                            ],
                        }
                    }
                },
                {
                    "add": {
                        "rule": {
                            "family": "ip",
                            "table": "filter",
                            "chain": "FORWARD",
                            "expr": [
                                {
                                    "match": {
                                        "left": {"meta": {"key": "iifname"}},
                                        "op": "==",
                                        "right": iface_name,
                                    }
                                },
                                {
                                    "match": {
                                        "left": {"meta": {"key": "oifname"}},
                                        "op": "==",
                                        "right": tap_name,
                                    }
                                },
                                {
                                    "match": {
                                        "left": {"ct": {"key": "state"}},
                                        "op": "in",
                                        "right": ["established", "related"],
                                    }
                                },
                                {"counter": {"packets": 0, "bytes": 0}},
                                {"accept": None},
                            ],
                        }
                    }
                },
            ]
        }

        try:
            rc, output, error = self._nft.json_cmd(batch)
        except Exception as e:
            raise NetworkError(f"Failed to add NAT forwarding rule: {str(e)}") from e

        if rc != 0 and "File exists" not in str(error):
            raise NetworkError(f"Failed to add NAT forwarding rule: {error}")

        if self._config.verbose:
            self._logger.info("Added NAT forwarding rule")
            self._logger.debug(f"NAT forwarding rule: {output}")

    def get_nat_rules(self, table: str = "nat"):
        """Get all objects of an ``ip`` family nftables table.

        Args:
            table (str): Table to list. Defaults to ``"nat"``. ``"filter"``
                holds the per-TAP forwarding rules created by add_nat_rules.

        Returns:
            list: The table's JSON objects (table, chains and rules), or an
            empty list when nftables is unavailable.

        Raises:
            NetworkError: If getting the rules fails.
        """
        cmd = {"nftables": [{"list": {"table": {"family": "ip", "name": table}}}]}
        try:
            rc, output, error = self._safe_nft_cmd(cmd)
        except Exception as e:
            raise NetworkError(f"Failed to get NAT rules: {str(e)}") from e

        if rc is None:  # nftables not available
            return []
        if rc != 0:
            raise NetworkError(f"Failed to get NAT rules: {error}")
        if output and "nftables" in output:
            return output["nftables"]
        return []

    def get_masquerade_handle(self):
        """Get the handle of the shared ``"microVM outbound NAT"`` masquerade rule.

        Returns:
            int | None: The handle if a ``POSTROUTING`` rule with that comment
            and a masquerade action exists, ``None`` otherwise (including when
            the nat table cannot be listed).
        """
        try:
            rules = self._list_table_rules("nat")
        except NetworkError:
            return None

        for rule in rules:
            if rule.get("chain") != "POSTROUTING":
                continue
            if rule.get("comment", "") != self._MASQUERADE_COMMENT:
                continue
            if any("masquerade" in expr for expr in rule.get("expr", [])):
                if self._config.verbose:
                    self._logger.debug(
                        f"Found masquerade rule with handle {rule.get('handle')}"
                    )
                return rule.get("handle")
        return None

    def create_masquerade(self, iface_name: str):
        """Ensure the shared masquerade rule exists for outbound traffic.

        Only one rule tagged ``"microVM outbound NAT"`` is created. If one
        already exists (for any interface), nothing is added.

        Args:
            iface_name (str): Host interface matched by ``oifname``.

        Returns:
            bool: True if the rule exists or was created, False if nftables
            rejected the add.

        Raises:
            NetworkError: If creating the rule raises.
        """
        try:
            if self.get_masquerade_handle() is not None:
                if self._config.verbose:
                    self._logger.debug("Masquerade rule already exists")
                return True

            add_cmd = {
                "nftables": [
                    {
                        "add": {
                            "rule": {
                                "family": "ip",
                                "table": "nat",
                                "chain": "POSTROUTING",
                                "comment": self._MASQUERADE_COMMENT,
                                "expr": [
                                    {
                                        "match": {
                                            "op": "==",
                                            "left": {"meta": {"key": "oifname"}},
                                            "right": iface_name,
                                        }
                                    },
                                    {"counter": {"packets": 0, "bytes": 0}},
                                    {"masquerade": None},
                                ],
                            }
                        }
                    }
                ]
            }

            rc = self._nft.json_cmd(add_cmd)[0]
            if rc != 0:
                return False
            if self._config.verbose:
                self._logger.info("Created masquerade rule")
            return True

        except Exception as e:
            raise NetworkError(f"Failed to create masquerade rule: {str(e)}") from e

    # ------------------------------------------------------------------
    # Port forwarding
    # ------------------------------------------------------------------

    def get_port_forward_handles(
        self, host_ip: str, host_port: int, dest_ip: str, dest_port: int
    ):
        """Get port forwarding rules from the ``ip nat`` table by their contents.

        Checks for both:
        - PREROUTING rules that DNAT host_ip:host_port to dest_ip:dest_port.
        - POSTROUTING rules with any ``saddr`` match and a masquerade action.

        Args:
            host_ip (str): IP address to forward from.
            host_port (int): Port to forward.
            dest_ip (str): IP address to forward to.
            dest_port (int): Port to forward to.

        Returns:
            dict: Optional ``"prerouting"`` / ``"postrouting"`` handles
            (the last matching rule of each kind wins).

        Raises:
            NetworkError: If retrieving nftables rules fails.
        """
        try:
            handles = {}
            for rule in self._list_table_rules("nat"):
                if rule.get("family") != "ip" or rule.get("table") != "nat":
                    continue
                chain = rule.get("chain", "").upper()  # normalize for matching
                exprs = rule.get("expr", [])

                if chain == "PREROUTING":
                    has_daddr = self._has_payload_value(exprs, "daddr", host_ip)
                    has_dport = self._has_payload_value(exprs, "dport", host_port)
                    has_dnat = any(
                        isinstance(e.get("dnat"), dict)
                        and e["dnat"].get("addr") == dest_ip
                        and e["dnat"].get("port") == dest_port
                        for e in exprs
                    )
                    if has_daddr and has_dport and has_dnat:
                        if self._config.verbose:
                            self._logger.debug(
                                f"Found matching prerouting port forward rule {rule}"
                            )
                            self._logger.info(
                                f"Found prerouting rule with handle {rule['handle']}"
                            )
                        handles["prerouting"] = rule["handle"]

                elif chain == "POSTROUTING":
                    has_saddr = any(
                        self._payload_match(e, "saddr") is not None for e in exprs
                    )
                    has_masquerade = any("masquerade" in e for e in exprs)
                    if has_saddr and has_masquerade:
                        if self._config.verbose:
                            self._logger.debug(
                                f"Found matching postrouting masquerade rule {rule}"
                            )
                            self._logger.info(
                                f"Found postrouting rule with handle {rule['handle']}"
                            )
                        handles["postrouting"] = rule["handle"]

            if not handles and self._config.verbose:
                self._logger.info("No port forwarding rules found")
            return handles

        except NetworkError:
            raise
        except Exception as e:
            raise NetworkError(f"Failed to get nftables rules: {str(e)}") from e

    def get_port_forward_by_comment(self, id: str, host_port: int, dest_port: int):
        """Get the PREROUTING port forwarding rule whose comment matches exactly.

        Args:
            id (str): Machine ID to search for.
            host_port (int): Host port to search for.
            dest_port (int): Destination port to search for.

        Returns:
            dict: ``{"prerouting": handle}`` if found, else an empty dict.

        Raises:
            NetworkError: If retrieving nftables rules fails.
        """
        wanted = self._port_forward_comment(id, host_port, dest_port)
        try:
            found = {}
            for rule in self._list_table_rules("nat"):
                if (
                    rule.get("family") == "ip"
                    and rule.get("table") == "nat"
                    and rule.get("chain", "").upper() == "PREROUTING"
                    and rule.get("comment", "") == wanted
                ):
                    if self._config.verbose:
                        self._logger.info(
                            f"Found prerouting rule with matching comment: {wanted}"
                        )
                        self._logger.debug(f"Rule details: {rule}")
                    found["prerouting"] = rule["handle"]

            if not found and self._config.verbose:
                self._logger.info(
                    f"No port forwarding rules found for machine_id={id} host_port={host_port} vm_port={dest_port}"
                )
            return found

        except NetworkError:
            raise
        except Exception as e:
            raise NetworkError(f"Failed to get nftables rules: {str(e)}") from e

    def _check_postrouting_exists(self, id: str) -> bool:
        """Check if a POSTROUTING rule commented ``machine_id=<id>`` exists.

        Args:
            id (str): Machine ID to check for.

        Returns:
            bool: True if the rule exists. False otherwise, including when the
            check itself fails (best effort, logged when verbose).
        """
        wanted = f"machine_id={id}"
        try:
            for rule in self._list_table_rules("nat"):
                if (
                    rule.get("family") == "ip"
                    and rule.get("table") == "nat"
                    and rule.get("chain", "").upper() == "POSTROUTING"
                    and rule.get("comment", "") == wanted
                ):
                    if self._config.verbose:
                        self._logger.debug(
                            f"Found existing POSTROUTING rule for machine_id={id}"
                        )
                    return True
            return False

        except Exception as e:
            if self._config.verbose:
                self._logger.warn(
                    f"Failed to check for existing POSTROUTING rule: {str(e)}"
                )
            return False

    def add_port_forward(
        self,
        id: str,
        host_ip: str,
        host_port: int,
        dest_ip: str,
        dest_port: int,
        protocol: str = "tcp",
    ):
        """Forward host_ip:host_port to dest_ip:dest_port with DNAT.

        Adds a commented PREROUTING DNAT rule. If no POSTROUTING rule for this
        machine exists yet, it also adds the POSTROUTING chain and a masquerade
        rule for traffic sourced from ``dest_ip``.

        Args:
            id (str): Machine ID recorded in the rule comments.
            host_ip (str): IP address to forward from (selects ip / ip6 family).
            host_port (int): Port to forward.
            dest_ip (str): IP address to forward to.
            dest_port (int): Port to forward to.
            protocol (str): Protocol to forward (default: "tcp").

        Returns:
            True if the PREROUTING rule already existed, otherwise None.

        Raises:
            NetworkError: If host_ip is invalid or adding a rule fails.
        """
        try:
            ip = ipaddress.ip_address(host_ip)
        except ValueError as e:
            raise NetworkError(f"Invalid IP address: {host_ip}") from e

        if isinstance(ip, IPv4Address):
            family, prefix_len = "ip", 32
        else:
            family, prefix_len = "ip6", 128

        if self.get_port_forward_by_comment(id, host_port, dest_port):
            if self._config.verbose:
                self._logger.info("Port forwarding rules already exist")
            return True

        postrouting_exists = self._check_postrouting_exists(id)

        commands = [
            {"add": {"table": {"family": family, "name": "nat"}}},
            self._base_chain(family, "nat", "PREROUTING", "nat", "prerouting", -100),
        ]
        if not postrouting_exists:
            commands.append(
                self._base_chain(family, "nat", "POSTROUTING", "nat", "postrouting", 100)
            )

        commands.append(
            {
                "add": {
                    "rule": {
                        "family": family,
                        "table": "nat",
                        "chain": "PREROUTING",
                        "comment": self._port_forward_comment(id, host_port, dest_port),
                        "expr": [
                            {
                                "match": {
                                    "op": "==",
                                    "left": {
                                        "payload": {"protocol": family, "field": "daddr"}
                                    },
                                    "right": host_ip,
                                }
                            },
                            {
                                "match": {
                                    "op": "==",
                                    "left": {
                                        "payload": {"protocol": protocol, "field": "dport"}
                                    },
                                    "right": host_port,
                                }
                            },
                            {"dnat": {"addr": dest_ip, "port": dest_port}},
                        ],
                    }
                }
            }
        )

        if not postrouting_exists:
            commands.append(
                {
                    "add": {
                        "rule": {
                            "family": family,
                            "table": "nat",
                            "chain": "POSTROUTING",
                            "comment": f"machine_id={id}",
                            "expr": [
                                {
                                    "match": {
                                        "op": "==",
                                        "left": {
                                            "payload": {"protocol": family, "field": "saddr"}
                                        },
                                        "right": {
                                            "prefix": {"addr": dest_ip, "len": prefix_len}
                                        },
                                    }
                                },
                                {"masquerade": None},
                            ],
                        }
                    }
                }
            )

        # Objects that already exist are not an error when re-applying.
        ignore_errors = ("File exists", "already exists")
        try:
            # Commands are sent one at a time so one "exists" error does not
            # abort the rest of the batch.
            for command in commands:
                rc, _, error = self._nft.json_cmd({"nftables": [command]})
                if rc != 0 and not any(marker in str(error) for marker in ignore_errors):
                    raise NetworkError(f"Failed to add port forwarding rule: {error}")

            if self._config.verbose:
                self._logger.info(
                    f"Added port forwarding rule: {host_ip}:{host_port} -> {dest_ip}:{dest_port}"
                )

        except NetworkError:
            raise
        except Exception as e:
            raise NetworkError(f"Failed to add port forwarding rules: {str(e)}") from e

    # ------------------------------------------------------------------
    # Deletion
    # ------------------------------------------------------------------

    def delete_rule(self, rule):
        """Delete a single rule from the ``ip filter`` table.

        Args:
            rule (dict): Dict with ``"chain"`` and ``"handle"`` keys, as
                returned by find_tap_interface_rules.

        Returns:
            bool: True if the rule was successfully deleted, False otherwise.

        Raises:
            NetworkError: If the nftables call raises.
        """
        try:
            rc, _, error = self._nft.cmd(
                f"delete rule filter {rule['chain']} handle {rule['handle']}"
            )
        except Exception as e:
            raise NetworkError(f"Failed to delete rule: {str(e)}") from e

        if self._config.verbose:
            if rc == 0:
                self._logger.debug(f"Rule with handle {rule['handle']} deleted")
            else:
                self._logger.error(
                    f"Error deleting rule with handle {rule['handle']}: {error}"
                )
        return rc == 0

    def delete_nat_rules(self, tap_name):
        """Delete the FORWARD rules that add_nat_rules created for a tap device.

        Those rules live in the ``ip filter`` table, which is also the table
        delete_rule deletes from. A missing table means there is nothing to
        delete.

        Args:
            tap_name (str): Name of the tap device to delete rules for.

        Raises:
            NetworkError: If listing or deleting the rules fails.
        """
        try:
            try:
                rules = self.get_nat_rules(table="filter")
            except NetworkError as e:
                if "No such file or directory" in str(e):
                    return
                raise

            tap_rules = self.find_tap_interface_rules(rules, tap_name)
            if self._config.verbose:
                self._logger.debug(f"Found {len(tap_rules)} rules for {tap_name}")

            for rule in tap_rules:
                if self.delete_rule(rule) and self._config.verbose:
                    self._logger.debug(f"Deleted rule with handle {rule['handle']}")

            if tap_rules and self._config.verbose:
                self._logger.info("Deleted NAT rules")

        except Exception as e:
            raise NetworkError(f"Failed to delete NAT rules: {str(e)}") from e

    def delete_masquerade(self):
        """Delete the shared ``"microVM outbound NAT"`` masquerade rule, if present.

        Raises:
            NetworkError: If deleting the masquerade rule raises.
        """
        try:
            handle = self.get_masquerade_handle()
            if handle is None:
                return

            process = run(
                f"nft delete rule nat POSTROUTING handle {handle}",
                capture_output=True,
                timeout=5,
            )
            if process.returncode == 0:
                if self._config.verbose:
                    self._logger.debug(f"Deleted masquerade rule with handle {handle}")
                    self._logger.info("Deleted masquerade rules")
            elif self._config.verbose:
                # run() output is treated as str elsewhere (get_interface_name);
                # accept either bytes or str here.
                stderr = process.stderr
                if isinstance(stderr, bytes):
                    stderr = stderr.decode(errors="replace")
                self._logger.warn(
                    f"Error deleting masquerade rule with handle {handle}: {stderr}"
                )

        except Exception as e:
            raise NetworkError(f"Failed to delete masquerade rule: {str(e)}") from e

    def delete_port_forward(self, id: str, host_port: int, dest_port: int):
        """Delete the port forwarding rule for one machine / port pair.

        Only rules whose comment equals
        ``"machine_id=<id> host_port=<host_port> vm_port=<dest_port>"`` exactly
        are deleted.

        Args:
            id (str): Machine ID for which port forwarding is being deleted.
            host_port (int): Host port being forwarded.
            dest_port (int): Destination port being forwarded to.

        Raises:
            ValueError: If host_port is out of range or id is empty.
            NetworkError: If deleting port forwarding rules fails.
        """
        if not isinstance(host_port, int) or host_port < 1 or host_port > 65535:
            raise ValueError(
                f"Invalid host port number: {host_port}. Must be between 1 and 65535."
            )

        if not id:
            raise ValueError("id cannot be empty")

        wanted = self._port_forward_comment(id, host_port, dest_port)
        try:
            for rule in self._list_table_rules("nat"):
                if rule.get("comment", "") == wanted:
                    self._delete_nat_rule(rule.get("chain", ""), rule["handle"])

            if self._config.verbose:
                self._logger.info(
                    f"Deleted port forwarding rule for {id} with host port {host_port}"
                )

        except Exception as e:
            raise NetworkError(f"Failed to delete port forward rules: {str(e)}") from e

    def delete_all_port_forward(self, id: str):
        """Delete all PREROUTING/POSTROUTING rules tagged for a machine ID.

        Args:
            id (str): Machine ID whose ``machine_id=<id>`` rules are deleted.
                Other IDs sharing a prefix (e.g. ``vm10`` for ``vm1``) are not
                touched.

        Raises:
            NetworkError: If deleting port forwarding rules fails.
        """
        try:
            targets = [
                (rule.get("chain", ""), rule["handle"])
                for rule in self._list_table_rules("nat")
                if rule.get("chain", "").upper() in ("PREROUTING", "POSTROUTING")
                and self._comment_belongs_to(rule.get("comment", ""), id)
            ]

            if not targets:
                if self._config.verbose:
                    self._logger.info("No port forwarding rules found")
                return

            for chain, handle in targets:
                self._delete_nat_rule(chain, handle)

            if self._config.verbose:
                self._logger.info(f"Deleted all port forwarding rules for {id}")

        except Exception as e:
            raise NetworkError(f"Failed to delete port forward rules: {str(e)}") from e

    # ------------------------------------------------------------------
    # Address planning
    # ------------------------------------------------------------------

    def detect_cidr_conflict(self, ip_addr: str, prefix_len: int = 24) -> bool:
        """Check whether ip_addr/prefix_len overlaps any IPv4 address on the host.

        Args:
            ip_addr (str): IP address to check for conflicts.
            prefix_len (int): Network prefix length (default 24).

        Returns:
            bool: True if a conflict exists, False otherwise.

        Raises:
            NetworkError: If the IP address format is invalid or the interface
                query fails.
        """
        try:
            new_network = IPv4Network(f"{ip_addr}/{prefix_len}", strict=False)
        except (AddressValueError, ValueError) as e:
            raise NetworkError(f"Invalid IP address format: {str(e)}") from e

        try:
            for iface in self._ipr.get_links():
                for addr in self._ipr.get_addr(index=iface["index"]):
                    for attr_name, attr_value in addr.get("attrs", []):
                        # Only IPv4 interface addresses are compared.
                        if attr_name != "IFA_ADDRESS" or ":" in attr_value:
                            continue
                        existing_network = IPv4Network(
                            f"{attr_value}/{addr.get('prefixlen', 24)}", strict=False
                        )
                        if new_network.overlaps(existing_network):
                            if self._config.verbose:
                                self._logger.warn(
                                    f"CIDR conflict detected: {new_network} "
                                    f"overlaps with existing {existing_network}"
                                )
                            return True
            return False

        except Exception as e:
            raise NetworkError(f"Failed to check CIDR conflicts: {str(e)}") from e

    def suggest_non_conflicting_ip(
        self, preferred_ip: str, prefix_len: int = 24
    ) -> str:
        """Suggest a non-conflicting IPv4 address based on the preferred IP.

        Tries up to 10 candidates by incrementing the third octet (mod 256)
        and keeping the other octets.

        Args:
            preferred_ip (str): Preferred IP address.
            prefix_len (int): Network prefix length.

        Returns:
            str: A non-conflicting IP address.

        Raises:
            NetworkError: If preferred_ip is invalid, is not IPv4, or no
                candidate is free.
        """
        try:
            ip_obj = ipaddress.ip_address(preferred_ip)
        except ValueError as e:
            raise NetworkError(f"Failed to suggest non-conflicting IP: {str(e)}") from e

        if isinstance(ip_obj, IPv4Address):
            octets = str(ip_obj).split(".")
            for step in range(1, 11):
                third_octet = (int(octets[2]) + step) % 256
                candidate = f"{octets[0]}.{octets[1]}.{third_octet}.{octets[3]}"
                if not self.detect_cidr_conflict(candidate, prefix_len):
                    if self._config.verbose:
                        self._logger.debug(f"Suggested non-conflicting IP: {candidate}")
                    return candidate

        raise NetworkError("Unable to find a non-conflicting IP address")

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self):
        """Release the nftables handle and close the pyroute2 socket (best effort)."""
        try:
            if self._nft:
                self._nft = None
            if self._ipr:
                self._ipr.close()
                self._ipr = None
        except Exception:
            # Deliberately best-effort: close() must not raise during teardown.
            pass

    def create_tap(
        self, tap_name: str = None, iface_name: str = None, gateway_ip: str = None
    ) -> None:
        """Create, address and bring up a new tap device using pyroute2.

        On failure, cleanup(tap_name) is attempted. A failure of that rollback
        is logged and does not replace the original error.

        Args:
            tap_name (str): Name for the new tap device (required, at most
                15 characters: the kernel's IFNAMSIZ is 16 including the NUL).
            iface_name (str, optional): Host interface name (validated only).
            gateway_ip (str, optional): IPv4 address assigned with prefix /24.

        Raises:
            ConfigurationError: If tap_name is missing.
            ValueError: If tap_name or iface_name is too long.
            NetworkError: If tap device creation or configuration fails.
        """
        if not tap_name:
            raise ConfigurationError("TAP device name is required")
        if len(tap_name) > 15:
            # pyroute2 issue: https://github.com/svinota/pyroute2/issues/452#issuecomment-363702389
            raise ValueError("TAP device name must not exceed 15 characters")
        if iface_name and len(iface_name) > 16:
            raise ValueError("Interface name must not exceed 16 characters")

        try:
            self._ipr.link("add", ifname=tap_name, kind="tuntap", mode="tap")
            idx = self._ipr.link_lookup(ifname=tap_name)[0]
            if gateway_ip:
                self._ipr.addr("add", index=idx, address=gateway_ip, prefixlen=24)

            self._ipr.link("set", index=idx, state="up")

            if self._config.verbose:
                self._logger.debug(f"Created TAP device {tap_name}")

        except Exception as e:
            # NOTE(review): if the failure was "device already exists", this
            # rollback removes the pre-existing device and its rules.
            try:
                self.cleanup(tap_name)
            except Exception as cleanup_error:
                if self._config.verbose:
                    self._logger.warn(
                        f"Cleanup after failed TAP creation also failed: {cleanup_error}"
                    )
            raise NetworkError(f"Failed to create TAP device {tap_name}: {str(e)}") from e

    def delete_tap(self, name: str) -> bool:
        """Delete a tap device using pyroute2.

        Args:
            name (str): Name of the tap device to delete.

        Returns:
            bool: True if the device existed and was deleted, False if it did
            not exist.

        Raises:
            NetworkError: If the lookup or deletion fails.
        """
        try:
            if not self.check_tap_device(name):
                return False
            idx = self._ipr.link_lookup(ifname=name)[0]
            self._ipr.link("del", index=idx)
            if self._config.verbose:
                self._logger.info(f"Removed tap device {name}")
            return True

        except Exception as e:
            raise NetworkError(f"Failed to delete tap device {name}: {str(e)}") from e

    def cleanup(self, tap_device: str):
        """Clean up network resources for a tap device.

        Deletes the tap's FORWARD rules, the shared masquerade rule, all port
        forwarding rules of the derived machine ID, and finally the device.

        Args:
            tap_device (str): Name of the tap device to clean up.

        Raises:
            NetworkError: If any step fails.
        """
        try:
            self.delete_nat_rules(tap_device)
            # NOTE(review): assumes TAP names are a 4-character prefix followed
            # by the machine ID (e.g. "tap_<id>"); confirm with the caller.
            machine_id = tap_device[4:]

            # NOTE(review): create_masquerade keeps a single rule for all
            # microVMs, so deleting it here may cut outbound NAT for other
            # running VMs until setup() recreates it.
            self.delete_masquerade()
            self.delete_all_port_forward(machine_id)
            self.delete_tap(tap_device)

        except Exception as e:
            raise NetworkError(f"Failed to cleanup network resources: {str(e)}") from e