primitive 0.2.66__py3-none-any.whl → 0.2.70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of primitive might be problematic; see the package's registry page for more details.

@@ -1,10 +1,12 @@
1
1
  from pathlib import Path
2
-
2
+ from subprocess import PIPE, Popen
3
+ import json
3
4
  import requests
4
5
  import serial
5
6
  from loguru import logger
6
7
  from paramiko import SSHClient
7
8
  from typing import TypedDict
9
+ import re
8
10
 
9
11
  from primitive.messaging.provider import MESSAGE_TYPES
10
12
  from primitive.utils.actions import BaseAction
@@ -39,6 +41,11 @@ def mac_address_manufacturer_style_to_ieee(mac: str) -> str:
39
41
  return ":".join(mac[i : i + 2] for i in range(0, 12, 2))
40
42
 
41
43
 
44
def natural_interface_key(s):
    """Return a sort key that orders interface names "naturally".

    The name is split into alternating text and digit runs, and each digit run
    is compared as an integer. So "Ethernet2" sorts before "Ethernet10", and
    "Ethernet1/2" sorts before "Ethernet1/10".

    Args:
        s (str): interface name, e.g. "Ethernet1/10".

    Returns:
        list: alternating ``str`` / ``int`` parts, usable as a ``sorted`` key.
    """
    # With a capturing group, re.split always alternates: even indices hold the
    # (possibly empty) text between numbers, odd indices hold the \d+ matches.
    # Converting by position keeps str/int types aligned across keys. It also
    # never calls int() on text such as "²", which str.isdigit() accepts but
    # \d does not match.
    return [
        int(part) if index % 2 else part
        for index, part in enumerate(re.split(r"(\d+)", s))
    ]
47
+
48
+
42
49
  class SwitchConnectionInfo(TypedDict):
43
50
  vendor: str
44
51
  hostname: str
@@ -46,6 +53,12 @@ class SwitchConnectionInfo(TypedDict):
46
53
  password: str
47
54
 
48
55
 
56
+ class MacAddressEntry(TypedDict):
57
+ ip_address: str | None
58
+ mac_address: str
59
+ vlan: str
60
+
61
+
49
62
  class Network(BaseAction):
50
63
  def __init__(self, *args, **kwargs) -> None:
51
64
  super().__init__(*args, **kwargs)
@@ -137,6 +150,8 @@ class Network(BaseAction):
137
150
  return None
138
151
 
139
152
  def get_switch_info(self):
153
+ if self.switch_connection_info is None:
154
+ self.primitive.hardware.get_and_set_switch_info()
140
155
  if self.is_switch_api_enabled():
141
156
  switch_info = self.get_switch_info_via_api()
142
157
  if switch_info:
@@ -145,26 +160,46 @@ class Network(BaseAction):
145
160
  return None
146
161
 
147
162
  def get_interfaces_info(self):
163
+ if self.switch_connection_info is None:
164
+ self.primitive.hardware.get_and_set_switch_info()
148
165
  if self.is_switch_api_enabled():
149
166
  interfaces_info = self.get_interfaces_via_api()
150
167
  mac_address_info = self.get_mac_address_info_via_api()
168
+
151
169
  ip_arp_table_info = self.get_ip_arp_table_via_api()
170
+ controllers_neighbors = self.get_ip_arp_table_via_ip_command()
152
171
 
153
172
  if interfaces_info and mac_address_info and ip_arp_table_info:
154
173
  for interface, mac_info in mac_address_info.items():
155
- if interface in interfaces_info.get("interfaces", {}):
156
- interfaces_info["interfaces"][interface]["mac_address"] = (
157
- mac_info.get("macAddress", "")
158
- )
159
- if interface in ip_arp_table_info:
160
- for ip_arp in ip_arp_table_info[interface]:
161
- if (
162
- interfaces_info["interfaces"][interface]["ip_address"]
163
- == ""
164
- ):
165
- interfaces_info["interfaces"][interface][
174
+ if interface in interfaces_info:
175
+ mac_addresses: dict[str, MacAddressEntry] = {}
176
+ for entry in mac_info:
177
+ mac_addresses[entry.get("macAddress", "")] = {
178
+ "mac_address": entry.get("macAddress", ""),
179
+ "ip_address": None,
180
+ "vlan": entry.get("vlanId", ""),
181
+ }
182
+
183
+ for neighbor in controllers_neighbors:
184
+ if neighbor.get("lladdr", "") in mac_addresses:
185
+ mac_addresses[neighbor.get("lladdr", "")][
166
186
  "ip_address"
167
- ] = ip_arp.get("ip_address", "")
187
+ ] = neighbor.get("dst", None)
188
+
189
+ interfaces_info[interface]["mac_addresses"] = mac_addresses
190
+
191
+ if interface in ip_arp_table_info:
192
+ for ip_arp in ip_arp_table_info[interface]:
193
+ for mac_address_entry in interfaces_info[interface][
194
+ "mac_addresses"
195
+ ].values():
196
+ if (
197
+ ip_arp.get("mac_address", "")
198
+ in mac_address_entry["mac_address"]
199
+ ):
200
+ mac_address_entry["ip_address"] = ip_arp.get(
201
+ "ip_address", None
202
+ )
168
203
 
169
204
  return interfaces_info
170
205
 
@@ -245,25 +280,27 @@ class Network(BaseAction):
245
280
  # }
246
281
  arista_interfaces_info = response.get("result", [])[0]
247
282
  formatted_interfaces_info = {
248
- "interfaces": {
249
- k: {
250
- "interface_name": k,
251
- "interface_type": v.get("interfaceType", ""),
252
- "link_status": v.get("linkStatus", ""),
253
- "line_protocol_status": v.get("lineProtocolStatus", ""),
254
- "mac_address": "",
255
- "ip_address": "",
256
- }
257
- for k, v in dict(
258
- sorted(
259
- arista_interfaces_info.get(
260
- "interfaceStatuses", {}
261
- ).items()
262
- )
263
- ).items()
264
- },
265
- # "raw_output": arista_interfaces_info,
283
+ k: {
284
+ "interface_name": k,
285
+ "interface_type": v.get("interfaceType", ""),
286
+ "link_status": v.get("linkStatus", ""),
287
+ "line_protocol_status": v.get("lineProtocolStatus", ""),
288
+ "mac_addresses": {},
289
+ }
290
+ for k, v in dict(
291
+ sorted(
292
+ arista_interfaces_info.get("interfaceStatuses", {}).items()
293
+ )
294
+ ).items()
266
295
  }
296
+ if formatted_interfaces_info:
297
+ formatted_interfaces_info = {
298
+ k: formatted_interfaces_info[k]
299
+ for k in sorted(
300
+ formatted_interfaces_info.keys(), key=natural_interface_key
301
+ )
302
+ }
303
+
267
304
  return formatted_interfaces_info
268
305
 
269
306
  def get_mac_address_info_via_api(self):
@@ -297,9 +334,14 @@ class Network(BaseAction):
297
334
  .get("tableEntries", [])
298
335
  )
299
336
  table_entries.sort(key=lambda x: x["lastMove"])
337
+
300
338
  for entry in table_entries:
301
339
  if entry.get("interface") not in interface_to_mac_address_info:
302
- interface_to_mac_address_info[entry.get("interface")] = entry
340
+ interface_to_mac_address_info[entry.get("interface")] = [entry]
341
+ else:
342
+ interface_to_mac_address_info[entry.get("interface")].append(
343
+ entry
344
+ )
303
345
 
304
346
  return interface_to_mac_address_info
305
347
 
@@ -329,6 +371,13 @@ class Network(BaseAction):
329
371
  )
330
372
  return ip_to_mac_address_info
331
373
 
374
def get_ip_arp_table_via_ip_command(self):
    """Return this host's IP neighbor table as reported by ``ip --json neigh show``.

    Each element is one neighbor object decoded from the command's JSON output.
    ``get_interfaces_info`` reads its ``"lladdr"`` key (matched against switch
    MAC addresses) and its ``"dst"`` key (used as the IP address).

    The neighbor table only enriches switch data with IP addresses, so failures
    are logged and reported as "no neighbors" rather than raised.

    Returns:
        list[dict]: decoded neighbor entries, or ``[]`` if ``ip`` cannot be run,
        exits non-zero, or prints something other than a JSON array.
    """
    command = ["ip", "--json", "neigh", "show"]
    command_text = " ".join(command)
    try:
        with Popen(command, stdout=PIPE) as process:
            raw_output = process.stdout.read()
    except OSError as exception:
        # e.g. the `ip` binary is not installed on this host.
        logger.warning(f"Could not run '{command_text}': {exception}")
        return []

    # Leaving the `with` block waits for the process, so returncode is set here.
    if process.returncode != 0:
        logger.warning(f"'{command_text}' exited with status {process.returncode}")
        return []

    try:
        neighbors = json.loads(raw_output.decode("utf-8"))
    except ValueError as exception:
        # Covers both json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(f"Could not parse output of '{command_text}': {exception}")
        return []

    # Callers iterate the result, so any non-list JSON value counts as empty.
    return neighbors if isinstance(neighbors, list) else []
380
+
332
381
  def serial_connect(self):
333
382
  self.ser = serial.Serial()
334
383
  self.ser.port = self.switch_tty_name
@@ -408,10 +457,10 @@ class Network(BaseAction):
408
457
 
409
458
  return False
410
459
 
411
- def push_switch_and_interfaces_info(self):
460
+ def push_switch_and_interfaces_info(self, interfaces_info: dict | None = None):
412
461
  if self.primitive.messaging.ready and self.switch_connection_info is not None:
413
462
  switch_info = self.get_switch_info()
414
- interfaces_info = self.get_interfaces_info()
463
+ interfaces_info = interfaces_info or self.get_interfaces_info()
415
464
 
416
465
  message = {"switch_info": {}, "interfaces_info": {}}
417
466
  if switch_info:
@@ -28,13 +28,24 @@ def switch(context):
28
28
  print_result(message=message, context=context)
29
29
 
30
30
 
31
- @cli.command("ports")
31
+ @cli.command("interfaces")
32
32
  @click.pass_context
33
- def ports(context):
34
- """Ports"""
33
+ @click.option(
34
+ "--push",
35
+ is_flag=True,
36
+ show_default=True,
37
+ default=False,
38
+ help="Push current interface info.",
39
+ )
40
+ def interfaces(context, push: bool = False):
41
+ """Interfaces"""
35
42
  primitive: Primitive = context.obj.get("PRIMITIVE")
36
- ports_info = primitive.network.get_interfaces_info()
43
+ interfaces_info = primitive.network.get_interfaces_info()
44
+ if push:
45
+ primitive.network.push_switch_and_interfaces_info(
46
+ interfaces_info=interfaces_info
47
+ )
37
48
  if context.obj["JSON"]:
38
- print_result(message=ports_info, context=context)
49
+ print_result(message=interfaces_info, context=context)
39
50
  else:
40
- render_ports_table(ports_info.get("interfaces"))
51
+ render_ports_table(interfaces_info)
@@ -0,0 +1,126 @@
1
+ from loguru import logger
2
+ import paramiko
3
+ import socket
4
+ import time
5
+ from paramiko import SSHClient
6
+
7
+
8
def test_ssh_connection(hostname, username, password=None, key_filename=None, port=22):
    """Make one SSH login attempt and report whether it worked.

    A password takes precedence over a key file. If neither is given, an error
    is printed and no connection is attempted. The client is always closed
    before returning, and every failure is printed rather than raised.

    Args:
        hostname (str): Host name or IP address of the SSH server.
        username (str): Login user.
        password (str, optional): Password to authenticate with.
        key_filename (str, optional): Private key file to authenticate with.
        port (int, optional): SSH port, 22 by default.

    Returns:
        bool: True when the login succeeded, otherwise False.
    """
    client = paramiko.SSHClient()
    # Unknown host keys are accepted automatically so a first contact succeeds.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    try:
        if password:
            credentials = {"password": password}
        elif key_filename:
            credentials = {"key_filename": key_filename}
        else:
            print(
                "Error: Either password or key_filename must be provided for authentication."
            )
            return False

        client.connect(hostname=hostname, port=port, username=username, **credentials)
        print(f"Successfully connected to {hostname} as {username}")
        return True
    except paramiko.AuthenticationException:
        print(f"Authentication failed for {username} on {hostname}")
        return False
    except paramiko.SSHException as exception:
        print(f"SSH error connecting to {hostname}: {exception}")
        return False
    except OSError as exception:  # socket.error is an alias of OSError
        print(f"Socket error connecting to {hostname}: {exception}")
        return False
    except Exception as exception:
        print(f"An unexpected error occurred: {exception}")
        return False
    finally:
        client.close()
61
+
62
+
63
# Default `timeout` for wait_for_ssh, in seconds.
TEN_MINUTES = 10 * 60
64
+
65
+
66
def wait_for_ssh(
    hostname,
    username,
    password=None,
    key_filename=None,
    port=22,
    timeout=TEN_MINUTES,
    poll_interval=10,
):
    """
    Waits until an SSH connection to a remote host can be established.

    Login is retried with ``test_ssh_connection`` every ``poll_interval``
    seconds until it succeeds or ``timeout`` seconds have elapsed.

    Args:
        hostname (str): The hostname or IP address of the remote SSH server.
        username (str): The username for authentication.
        password (str, optional): The password for authentication. Defaults to None.
        key_filename (str, optional): Path to the private key file for authentication. Defaults to None.
        port (int, optional): The SSH port. Defaults to 22.
        timeout (int, optional): Maximum time to wait in seconds. Defaults to TEN_MINUTES (600).
        poll_interval (float, optional): Seconds to sleep between attempts. Defaults to 10.

    Returns:
        bool: True if the connection is successful within the timeout, False otherwise.
    """
    # A monotonic clock is immune to wall-clock adjustments (NTP, manual changes)
    # that would otherwise shorten or stretch the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if test_ssh_connection(
            hostname, username, password=password, key_filename=key_filename, port=port
        ):
            return True
        logger.debug(f"Waiting for SSH to become available on {hostname}...")
        # Never sleep past the deadline; the last wait is trimmed to the time left.
        remaining = deadline - time.monotonic()
        if remaining > 0:
            time.sleep(min(poll_interval, remaining))

    logger.warning(
        f"Timeout reached: Unable to connect to {hostname} via SSH within {timeout} seconds."
    )
    return False
97
+
98
+
99
def run_command(
    hostname,
    username,
    command: str,
    password=None,
    key_filename=None,
    port=22,
):
    """Run one command on a remote host over SSH and log what it printed.

    Non-empty stdout is logged at INFO level and non-empty stderr at ERROR
    level, each with trailing newlines stripped. The SSH client is closed
    whether or not the command succeeds.

    Args:
        hostname (str): Host name or IP address of the SSH server.
        username (str): Login user.
        command (str): Command line to execute remotely.
        password (str, optional): Password to authenticate with.
        key_filename (str, optional): Private key file to authenticate with.
        port (int, optional): SSH port, 22 by default.

    Raises:
        Whatever ``SSHClient.connect`` / ``exec_command`` raise (e.g. paramiko
        authentication or socket errors); they are not caught here.
    """
    ssh_client = SSHClient()
    ssh_client.load_system_host_keys()
    ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        ssh_client.connect(
            hostname=hostname,
            port=port,
            username=username,
            password=password,
            key_filename=key_filename,
        )
        _stdin, stdout, stderr = ssh_client.exec_command(command)

        # errors="replace" keeps non-UTF-8 output from raising UnicodeDecodeError.
        stdout_string = stdout.read().decode("utf-8", errors="replace").rstrip("\n")
        stderr_string = stderr.read().decode("utf-8", errors="replace").rstrip("\n")
    finally:
        # Release the transport and socket even if connect/exec/read raised.
        ssh_client.close()

    # The decoded streams are str, so test for emptiness directly.
    if stdout_string:
        logger.info(stdout_string)
    if stderr_string:
        logger.error(stderr_string)
primitive/network/ui.py CHANGED
@@ -8,12 +8,18 @@ def render_ports_table(ports_dict) -> None:
8
8
  table = Table(show_header=True, header_style="bold #FFA800")
9
9
  table.add_column("Port")
10
10
  table.add_column("Status")
11
- table.add_column("MAC Address")
12
- table.add_column("IP Address")
11
+ table.add_column("MAC Address | IP | VLAN")
13
12
 
14
13
  for k, v in ports_dict.items():
15
14
  table.add_row(
16
- k, v.get("link_status"), v.get("mac_address"), v.get("ip_address")
15
+ k,
16
+ v.get("link_status"),
17
+ "\n".join(
18
+ [
19
+ f"{key} | {values.get('ip_address')} | VLAN {values.get('vlan')}"
20
+ for key, values in v.get("mac_addresses", {}).items()
21
+ ]
22
+ ),
17
23
  )
18
24
 
19
25
  console.print(table)
File without changes
@@ -0,0 +1,260 @@
1
+ from enum import Enum
2
+
3
+ from gql import gql
4
+ import requests
5
+
6
+ from primitive.operating_systems.graphql.mutations import (
7
+ operating_system_create_mutation,
8
+ )
9
+ from primitive.operating_systems.graphql.queries import operating_system_list_query
10
+ from primitive.utils.actions import BaseAction
11
+ from primitive.utils.auth import guard
12
+
13
+ from primitive.utils.cache import get_operating_systems_cache
14
+ from pathlib import Path
15
+ from urllib.request import urlopen
16
+ import os
17
+ from loguru import logger
18
+
19
+ from primitive.utils.checksums import get_checksum_from_file, calculate_sha256
20
+
21
+
22
+ class OperatingSystems(BaseAction):
23
+ def __init__(self, primitive):
24
+ super().__init__(primitive)
25
+ self.remote_operating_systems = {
26
+ "ubuntu-24-04-3": {
27
+ "iso": "https://releases.ubuntu.com/24.04.3/ubuntu-24.04.3-desktop-amd64.iso",
28
+ "checksum": "https://releases.ubuntu.com/24.04.3/SHA256SUMS",
29
+ "checksum_file_type": self.OperatingSystemChecksumFileType.SHA256SUMS,
30
+ },
31
+ }
32
+
33
+ class OperatingSystemChecksumFileType(Enum):
34
+ SHA256SUMS = "SHA256SUMS"
35
+
36
+ def get_remote_operating_system_names(self):
37
+ return list(self.remote_operating_systems.keys())
38
+
39
+ def _download_remote_operating_system_iso(self, remote_operating_system_name):
40
+ operating_system_dir = Path(
41
+ get_operating_systems_cache() / remote_operating_system_name
42
+ )
43
+ iso_dir = Path(operating_system_dir / "iso")
44
+ os.makedirs(iso_dir, exist_ok=True)
45
+
46
+ operating_system_info = self.remote_operating_systems[
47
+ remote_operating_system_name
48
+ ]
49
+ iso_remote_url = operating_system_info["iso"]
50
+ iso_filename = iso_remote_url.split("/")[-1]
51
+ iso_file_path = Path(iso_dir / iso_filename)
52
+
53
+ if iso_file_path.exists() and iso_file_path.is_file():
54
+ logger.info("Operating system iso already downloaded.")
55
+ return iso_file_path
56
+
57
+ logger.info(
58
+ f"Downloading operating system '{remote_operating_system_name}' iso. This may take a few minutes..."
59
+ )
60
+
61
+ session = requests.Session()
62
+ with session.get(iso_remote_url, stream=True) as response:
63
+ response.raise_for_status()
64
+ with open(iso_file_path, "wb") as f:
65
+ for chunk in response.iter_content(chunk_size=8192):
66
+ if chunk:
67
+ f.write(chunk)
68
+ f.flush()
69
+
70
+ logger.info(
71
+ f"Successfully downloaded operating system iso to '{iso_file_path}'."
72
+ )
73
+
74
+ return iso_file_path
75
+
76
+ def _download_remote_operating_system_checksum(self, remote_operating_system_name):
77
+ operating_system_dir = Path(
78
+ get_operating_systems_cache() / remote_operating_system_name
79
+ )
80
+ checksum_dir = Path(operating_system_dir / "checksum")
81
+ os.makedirs(checksum_dir, exist_ok=True)
82
+
83
+ operating_system_info = self.remote_operating_systems[
84
+ remote_operating_system_name
85
+ ]
86
+ checksum_filename = operating_system_info["checksum"].split("/")[-1]
87
+
88
+ checksum_file_path = Path(checksum_dir / checksum_filename)
89
+ if checksum_file_path.exists() and checksum_file_path.is_file():
90
+ logger.info("Operating system checksum already downloaded.")
91
+ return checksum_file_path
92
+
93
+ logger.info(
94
+ f"Downloading operating system '{remote_operating_system_name}' checksum."
95
+ )
96
+
97
+ checksum_response = urlopen(operating_system_info["checksum"])
98
+ checksum_file_content = checksum_response.read()
99
+ with open(checksum_file_path, "wb") as f:
100
+ f.write(checksum_file_content)
101
+
102
+ logger.info(f"Successfully downloaded checksum to '{checksum_file_path}'.")
103
+
104
+ return checksum_file_path
105
+
106
+ def download_from_remote(self, remote_operating_system_name: str):
107
+ remote_operating_system_names = self.get_remote_operating_system_names()
108
+
109
+ if remote_operating_system_name not in remote_operating_system_names:
110
+ logger.error(
111
+ f"No such operating system '{remote_operating_system_name}'. Run 'primitive operating-systems list' for available operating systems."
112
+ )
113
+ raise ValueError(
114
+ f"No such operating system '{remote_operating_system_name}'."
115
+ )
116
+
117
+ iso_file_path = self._download_remote_operating_system_iso(
118
+ remote_operating_system_name
119
+ )
120
+ checksum_file_path = self._download_remote_operating_system_checksum(
121
+ remote_operating_system_name
122
+ )
123
+
124
+ return iso_file_path, checksum_file_path
125
+
126
+ def get_checksum_file_type(self, operating_system_name: str):
127
+ return self.remote_operating_systems[operating_system_name][
128
+ "checksum_file_type"
129
+ ]
130
+
131
+ def validate_checksum(
132
+ self,
133
+ operating_system_name: str,
134
+ iso_file_path: str,
135
+ checksum_file_path: str,
136
+ checksum_file_type: OperatingSystemChecksumFileType | None = None,
137
+ ):
138
+ checksum_file_type = (
139
+ checksum_file_type
140
+ if checksum_file_type
141
+ else self.get_checksum_file_type(operating_system_name)
142
+ )
143
+
144
+ match checksum_file_type:
145
+ case self.OperatingSystemChecksumFileType.SHA256SUMS:
146
+ return self._validate_sha256_sums_checksum(
147
+ iso_file_path, checksum_file_path
148
+ )
149
+ case _:
150
+ logger.error(f"Invalid checksum file type: {checksum_file_type}")
151
+ raise ValueError(f"Invalid checksum file type: {checksum_file_type}")
152
+
153
+ def _validate_sha256_sums_checksum(self, iso_file_path, checksum_file_path):
154
+ iso_file_name = Path(iso_file_path).name
155
+
156
+ remote_checksum = get_checksum_from_file(checksum_file_path, iso_file_name)
157
+ local_checksum = calculate_sha256(iso_file_path)
158
+ return remote_checksum == local_checksum
159
+
160
+ @guard
161
+ def create_operating_system(
162
+ self,
163
+ slug: str,
164
+ organization_id: str,
165
+ checksum_file_id: str,
166
+ checksum_file_type: str,
167
+ iso_file_id: str,
168
+ ):
169
+ mutation = gql(operating_system_create_mutation)
170
+ input = {
171
+ "slug": slug,
172
+ "organization": organization_id,
173
+ "checksumFile": checksum_file_id,
174
+ "checksumFileType": checksum_file_type,
175
+ "isoFile": iso_file_id,
176
+ }
177
+ variables = {"input": input}
178
+ result = self.primitive.session.execute(
179
+ mutation, variable_values=variables, get_execution_result=True
180
+ )
181
+ return result.data.get("operatingSystemCreate")
182
+
183
+ @guard
184
+ def get_operating_system_list(
185
+ self,
186
+ organization_id: str,
187
+ slug: str | None = None,
188
+ id: str | None = None,
189
+ ):
190
+ query = gql(operating_system_list_query)
191
+
192
+ variables = {
193
+ "filters": {
194
+ "organization": {"id": organization_id},
195
+ }
196
+ }
197
+
198
+ if slug:
199
+ variables["filters"]["slug"] = {"exact": slug}
200
+
201
+ if id:
202
+ variables["filters"]["id"] = id
203
+
204
+ result = self.primitive.session.execute(
205
+ query, variable_values=variables, get_execution_result=True
206
+ )
207
+
208
+ return result
209
+
210
+ @guard
211
+ def get_operating_system(
212
+ self, organization_id: str, slug: str | None = None, id: str | None = None
213
+ ):
214
+ if not (slug or id):
215
+ raise Exception("Slug or id must be provided.")
216
+ if slug and id:
217
+ raise Exception("Only one of slug or id must be provided.")
218
+
219
+ operating_system_list_result = self.get_operating_system_list(
220
+ organization_id=organization_id, slug=slug, id=id
221
+ )
222
+
223
+ edges = operating_system_list_result.data.get("operatingSystemList").get(
224
+ "edges", []
225
+ )
226
+
227
+ if len(edges) == 0:
228
+ if slug:
229
+ logger.error(f"No operating system found for slug '{slug}'.")
230
+ raise Exception(f"No operating system found for slug {slug}.")
231
+ else:
232
+ logger.error(f"No operating system found for ID {id}.")
233
+ raise Exception(f"No operating system found for ID {id}.")
234
+
235
+ return edges[0].get("node")
236
+
237
+ @guard
238
+ def is_slug_available(self, slug: str, organization_id: str):
239
+ query = gql(operating_system_list_query)
240
+
241
+ variables = {
242
+ "filters": {
243
+ "slug": {"exact": slug},
244
+ "organization": {"id": organization_id},
245
+ }
246
+ }
247
+
248
+ result = self.primitive.session.execute(
249
+ query, variable_values=variables, get_execution_result=True
250
+ )
251
+
252
+ count = result.data.get("operatingSystemList").get("totalCount")
253
+
254
+ return count == 0
255
+
256
+ def is_operating_system_cached(self, slug: str, directory: str | None = None):
257
+ cache_dir = Path(directory) if directory else get_operating_systems_cache()
258
+ cache_path = cache_dir / slug
259
+
260
+ return cache_path.exists()