tunnel-manager 0.0.5__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tunnel-manager might be problematic. Click here for more details.

@@ -1,62 +1,71 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import sys
5
+ import argparse
6
+ import concurrent.futures
1
7
  import logging
2
8
  import os
3
9
  import paramiko
10
+ import yaml
4
11
 
5
12
 
6
13
  class Tunnel:
7
14
  def __init__(
8
15
  self,
9
16
  remote_host: str,
17
+ username: str = None,
18
+ password: str = None,
10
19
  port: int = 22,
11
20
  identity_file: str = None,
12
21
  certificate_file: str = None,
13
22
  proxy_command: str = None,
14
- log_file: str = None,
23
+ ssh_config_file: str = os.path.expanduser("~/.ssh/config"),
15
24
  ):
16
25
  """
17
26
  Initialize the Tunnel class.
18
27
 
19
28
  :param remote_host: The hostname or IP of the remote host.
29
+ :param username: The username for authentication (overrides config).
30
+ :param password: The password for authentication (if no identity_file).
31
+ :param port: The SSH port (default: 22).
20
32
  :param identity_file: Optional path to the private key file (overrides config).
21
- :param certificate_file: Optional path to the certificate file (overrides config, used for Teleport).
22
- :param proxy_command: Optional proxy command string (overrides config, used for Teleport proxying).
33
+ :param certificate_file: Optional path to the certificate file (overrides config).
34
+ :param proxy_command: Optional proxy command string (overrides config).
23
35
  :param log_file: Optional path to a log file for recording operations.
36
+ :param ssh_config_file: Optional path to a custom SSH config file (defaults to ~/.ssh/config).
24
37
  """
25
38
  self.remote_host = remote_host
39
+ self.username = username
40
+ self.password = password
26
41
  self.port = port
27
42
  self.ssh_client = None
28
43
  self.sftp = None
29
- self.logger = None
44
+ self.logger = logging.getLogger(__name__)
30
45
 
31
- # Load from ~/.ssh/config if not overridden
32
- ssh_config_path = os.path.expanduser("~/.ssh/config")
46
+ # Load SSH config from custom or default path
33
47
  self.ssh_config = paramiko.SSHConfig()
34
- if os.path.exists(ssh_config_path):
35
- with open(ssh_config_path) as f:
48
+ if os.path.exists(ssh_config_file) and os.path.isfile(ssh_config_file):
49
+ with open(ssh_config_file, "r") as f:
36
50
  self.ssh_config.parse(f)
51
+ self.logger.info(f"Loaded SSH config from: {ssh_config_file}")
52
+ else:
53
+ self.logger.warning(f"No SSH config found at: {ssh_config_file}")
37
54
  host_config = self.ssh_config.lookup(remote_host) or {}
38
55
 
56
+ self.username = username or host_config.get("user")
39
57
  self.identity_file = identity_file or (
40
- host_config.get("identityfile", [None])[0]
41
- if "identityfile" in host_config
58
+ host_config.get("identityfile")[0]
59
+ if host_config.get("identityfile")
42
60
  else None
43
61
  )
44
62
  self.certificate_file = certificate_file or host_config.get("certificatefile")
45
63
  self.proxy_command = proxy_command or host_config.get("proxycommand")
46
64
 
47
- if not self.identity_file:
48
- raise ValueError(
49
- "Identity file must be provided either via parameter or in ~/.ssh/config."
50
- )
51
-
52
- if log_file:
53
- logging.basicConfig(
54
- filename=log_file,
55
- level=logging.INFO,
56
- format="%(asctime)s - %(levelname)s - %(message)s",
57
- )
58
- self.logger = logging.getLogger(__name__)
59
- self.logger.info(f"Tunnel initialized for host: {remote_host}")
65
+ if not self.username:
66
+ raise ValueError("Username must be provided via parameter or SSH config.")
67
+ if not self.identity_file and not self.password:
68
+ raise ValueError("Either identity_file or password must be provided.")
60
69
 
61
70
  def connect(self):
62
71
  """
@@ -67,7 +76,7 @@ class Tunnel:
67
76
  and self.ssh_client.get_transport()
68
77
  and self.ssh_client.get_transport().is_active()
69
78
  ):
70
- return # Already connected
79
+ return
71
80
 
72
81
  self.ssh_client = paramiko.SSHClient()
73
82
  self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
@@ -75,30 +84,38 @@ class Tunnel:
75
84
  proxy = None
76
85
  if self.proxy_command:
77
86
  proxy = paramiko.ProxyCommand(self.proxy_command)
78
- if self.logger:
79
- self.logger.info(f"Using proxy command: {self.proxy_command}")
80
-
81
- private_key = paramiko.RSAKey.from_private_key_file(self.identity_file)
82
- if self.certificate_file:
83
- private_key.load_certificate(self.certificate_file)
84
- if self.logger:
85
- self.logger.info(f"Loaded certificate: {self.certificate_file}")
87
+ self.logger.info(f"Using proxy command: {self.proxy_command}")
86
88
 
87
89
  try:
88
- self.ssh_client.connect(
89
- self.remote_host,
90
- port=self.port,
91
- pkey=private_key,
92
- sock=proxy,
93
- auth_timeout=30,
94
- look_for_keys=False,
95
- allow_agent=False,
96
- )
97
- if self.logger:
98
- self.logger.info(f"Connected to {self.remote_host}")
90
+ if self.identity_file:
91
+ private_key = paramiko.RSAKey.from_private_key_file(self.identity_file)
92
+ if self.certificate_file:
93
+ private_key.load_certificate(self.certificate_file)
94
+ self.logger.info(f"Loaded certificate: {self.certificate_file}")
95
+ self.ssh_client.connect(
96
+ self.remote_host,
97
+ port=self.port,
98
+ username=self.username,
99
+ pkey=private_key,
100
+ sock=proxy,
101
+ auth_timeout=30,
102
+ look_for_keys=False,
103
+ allow_agent=False,
104
+ )
105
+ else:
106
+ self.ssh_client.connect(
107
+ self.remote_host,
108
+ port=self.port,
109
+ username=self.username,
110
+ password=self.password,
111
+ sock=proxy,
112
+ auth_timeout=30,
113
+ look_for_keys=False,
114
+ allow_agent=False,
115
+ )
116
+ self.logger.info(f"Connected to {self.remote_host}")
99
117
  except Exception as e:
100
- if self.logger:
101
- self.logger.error(f"Connection failed: {str(e)}")
118
+ self.logger.error(f"Connection failed: {str(e)}")
102
119
  raise
103
120
 
104
121
  def run_command(self, command):
@@ -113,14 +130,12 @@ class Tunnel:
113
130
  stdin, stdout, stderr = self.ssh_client.exec_command(command)
114
131
  out = stdout.read().decode("utf-8").strip()
115
132
  err = stderr.read().decode("utf-8").strip()
116
- if self.logger:
117
- self.logger.info(
118
- f"Command executed: {command}\nOutput: {out}\nError: {err}"
119
- )
133
+ self.logger.info(
134
+ f"Command executed: {command}\nOutput: {out}\nError: {err}"
135
+ )
120
136
  return out, err
121
137
  except Exception as e:
122
- if self.logger:
123
- self.logger.error(f"Command execution failed: {str(e)}")
138
+ self.logger.error(f"Command execution failed: {str(e)}")
124
139
  raise
125
140
 
126
141
  def send_file(self, local_path, remote_path):
@@ -135,11 +150,9 @@ class Tunnel:
135
150
  if not self.sftp:
136
151
  self.sftp = self.ssh_client.open_sftp()
137
152
  self.sftp.put(local_path, remote_path)
138
- if self.logger:
139
- self.logger.info(f"File sent: {local_path} -> {remote_path}")
153
+ self.logger.info(f"File sent: {local_path} -> {remote_path}")
140
154
  except Exception as e:
141
- if self.logger:
142
- self.logger.error(f"File send failed: {str(e)}")
155
+ self.logger.error(f"File send failed: {str(e)}")
143
156
  raise
144
157
  finally:
145
158
  if self.sftp:
@@ -158,33 +171,721 @@ class Tunnel:
158
171
  if not self.sftp:
159
172
  self.sftp = self.ssh_client.open_sftp()
160
173
  self.sftp.get(remote_path, local_path)
161
- if self.logger:
162
- self.logger.info(f"File received: {remote_path} -> {local_path}")
174
+ self.logger.info(f"File received: {remote_path} -> {local_path}")
163
175
  except Exception as e:
164
- if self.logger:
165
- self.logger.error(f"File receive failed: {str(e)}")
176
+ self.logger.error(f"File receive failed: {str(e)}")
166
177
  raise
167
178
  finally:
168
179
  if self.sftp:
169
180
  self.sftp.close()
170
181
  self.sftp = None
171
182
 
183
+ def check_ssh_server(self):
184
+ """
185
+ Check if the SSH server is running and configured for key-based auth on the remote host.
186
+ :return: Tuple (bool, str) indicating if SSH server is running and any error message.
187
+ """
188
+ try:
189
+ self.connect()
190
+ out, err = self.run_command(
191
+ "systemctl status sshd || ps aux | grep '[s]shd'"
192
+ )
193
+ if "running" in out.lower() or "sshd" in out.lower():
194
+ out, err = self.run_command(
195
+ "grep '^PubkeyAuthentication' /etc/ssh/sshd_config"
196
+ )
197
+ if "PubkeyAuthentication yes" in out:
198
+ return True, "SSH server running with key-based auth enabled."
199
+ return False, "SSH server running but key-based auth not enabled."
200
+ return False, "SSH server not running."
201
+ except Exception as e:
202
+ self.logger.error(f"Failed to check SSH server: {str(e)}")
203
+ return False, f"Failed to check SSH server: {str(e)}"
204
+ finally:
205
+ self.close()
206
+
207
+ def test_key_auth(self, local_key_path):
208
+ """
209
+ Test if key-based authentication works for the remote host.
210
+ :param local_key_path: Path to the private key to test.
211
+ :return: Tuple (bool, str) indicating success and any error message.
212
+ """
213
+ local_key_path = os.path.expanduser(local_key_path)
214
+ try:
215
+ temp_tunnel = Tunnel(
216
+ remote_host=self.remote_host,
217
+ username=self.username,
218
+ identity_file=local_key_path,
219
+ )
220
+ temp_tunnel.connect()
221
+ temp_tunnel.close()
222
+ return True, "Key-based authentication successful."
223
+ except Exception as e:
224
+ self.logger.error(f"Key auth test failed: {str(e)}")
225
+ return False, f"Key auth test failed: {str(e)}"
226
+
172
227
  def close(self):
173
228
  """
174
229
  Close the SSH connection.
175
230
  """
176
231
  if self.ssh_client:
177
232
  self.ssh_client.close()
178
- if self.logger:
179
- self.logger.info(f"Connection closed for {self.remote_host}")
233
+ self.logger.info(f"Connection closed for {self.remote_host}")
180
234
  self.ssh_client = None
181
235
 
236
+ def setup_passwordless_ssh(
237
+ self, local_key_path=os.path.expanduser("~/.ssh/id_rsa")
238
+ ):
239
+ """
240
+ Set up passwordless SSH by copying a public key to the remote host.
241
+ Requires password-based authentication to be configured.
242
+
243
+ :param local_key_path: Path to the local private key (public key is assumed to be .pub).
244
+ """
245
+ if not self.password:
246
+ raise ValueError("Password-based authentication required for setup.")
247
+
248
+ local_key_path = os.path.expanduser(local_key_path)
249
+ pub_key_path = local_key_path + ".pub"
250
+
251
+ if not os.path.exists(pub_key_path):
252
+ os.system(f"ssh-keygen -t rsa -b 4096 -f {local_key_path} -N ''")
253
+ self.logger.info(f"Generated key pair: {local_key_path}, {pub_key_path}")
254
+
255
+ with open(pub_key_path, "r") as f:
256
+ pub_key = f.read().strip()
257
+
258
+ try:
259
+ self.connect()
260
+ self.run_command("mkdir -p ~/.ssh && chmod 700 ~/.ssh")
261
+ self.run_command(f"echo '{pub_key}' >> ~/.ssh/authorized_keys")
262
+ self.run_command("chmod 600 ~/.ssh/authorized_keys")
263
+ self.logger.info(
264
+ f"Set up passwordless SSH for {self.username}@{self.remote_host}"
265
+ )
266
+ except Exception as e:
267
+ self.logger.error(f"Failed to set up passwordless SSH: {str(e)}")
268
+ raise
269
+ finally:
270
+ self.close()
271
+
272
+ @staticmethod
273
+ def execute_on_all(inventory, func, group="all", parallel=False, max_threads=5):
274
+ """
275
+ Execute a function on all hosts in the specified group of the YAML inventory, sequentially or in parallel.
276
+ :param inventory: Path to the YAML inventory file.
277
+ :param func: Function to execute, takes host dict as argument.
278
+ :param group: Inventory group to target (default: 'all').
279
+ :param parallel: Whether to run in parallel using threads.
280
+ :param max_threads: Maximum number of threads if parallel.
281
+ """
282
+ logger = logging.getLogger("Tunnel")
283
+ logger.info(f"Processing inventory '{inventory}' for group '{group}'")
284
+ print(f"Loading inventory '{inventory}' for group '{group}'...")
285
+
286
+ try:
287
+ with open(inventory, "r") as f:
288
+ inventory_data = yaml.safe_load(f)
289
+ logger.debug(f"Loaded inventory data: {inventory_data}")
290
+ except FileNotFoundError:
291
+ logger.error(f"Inventory file not found: {inventory}")
292
+ print(f"Error: Inventory file not found: {inventory}", file=sys.stderr)
293
+ raise
294
+ except yaml.YAMLError as e:
295
+ logger.error(f"Failed to parse inventory file: {str(e)}")
296
+ print(f"Error: Failed to parse inventory file: {str(e)}", file=sys.stderr)
297
+ raise
298
+
299
+ hosts = []
300
+ if (
301
+ group in inventory_data
302
+ and isinstance(inventory_data[group], dict)
303
+ and "hosts" in inventory_data[group]
304
+ and isinstance(inventory_data[group]["hosts"], dict)
305
+ ):
306
+ for host, vars in inventory_data[group]["hosts"].items():
307
+ host_entry = {
308
+ "hostname": vars.get("ansible_host", host),
309
+ "username": vars.get("ansible_user"),
310
+ "password": vars.get("ansible_ssh_pass"),
311
+ "key_path": vars.get("ansible_ssh_private_key_file"),
312
+ }
313
+ if not host_entry["username"]:
314
+ logger.error(
315
+ f"No username specified for host {host_entry['hostname']}"
316
+ )
317
+ print(
318
+ f"Error: No username specified for host {host_entry['hostname']}",
319
+ file=sys.stderr,
320
+ )
321
+ continue
322
+ logger.debug(f"Added host: {host_entry['hostname']}")
323
+ hosts.append(host_entry)
324
+ else:
325
+ logger.error(
326
+ f"Group '{group}' not found in inventory or invalid (hosts not a dict)"
327
+ )
328
+ print(
329
+ f"Error: Group '{group}' not found in inventory or invalid (hosts not a dict)",
330
+ file=sys.stderr,
331
+ )
332
+ raise ValueError(f"Group '{group}' not found in inventory or invalid")
333
+
334
+ logger.info(f"Found {len(hosts)} hosts in group '{group}'")
335
+ print(f"Found {len(hosts)} hosts in group '{group}'")
336
+
337
+ if not hosts:
338
+ logger.warning(f"No valid hosts found in group '{group}'")
339
+ print(f"Warning: No valid hosts found in group '{group}'")
340
+ return
341
+
342
+ if parallel:
343
+ with concurrent.futures.ThreadPoolExecutor(
344
+ max_workers=max_threads
345
+ ) as executor:
346
+ futures = [executor.submit(func, host) for host in hosts]
347
+ for future in concurrent.futures.as_completed(futures):
348
+ try:
349
+ future.result()
350
+ except Exception as e:
351
+ logger.error(f"Error in parallel execution: {str(e)}")
352
+ print(f"Error in parallel execution: {str(e)}", file=sys.stderr)
353
+ else:
354
+ for host in hosts:
355
+ func(host)
356
+ print(f"Completed processing group '{group}'")
357
+
358
+ def remove_host_key(
359
+ self, known_hosts_path=os.path.expanduser("~/.ssh/known_hosts")
360
+ ) -> str:
361
+ """
362
+ Remove the host key for the remote host from the known_hosts file.
363
+ :param known_hosts_path: Path to the known_hosts file (default: ~/.ssh/known_hosts).
364
+ """
365
+ known_hosts_path = os.path.expanduser(known_hosts_path)
366
+ kh = paramiko.HostKeys()
367
+ if os.path.exists(known_hosts_path):
368
+ kh.load(known_hosts_path)
369
+ if self.remote_host in kh:
370
+ del kh[self.remote_host]
371
+ kh.save(known_hosts_path)
372
+ self.logger.info(
373
+ f"Removed host key for {self.remote_host} from {known_hosts_path}"
374
+ )
375
+ return (
376
+ f"Removed host key for {self.remote_host} from {known_hosts_path}"
377
+ )
378
+ else:
379
+ self.logger.warning(
380
+ f"No host key found for {self.remote_host} in {known_hosts_path}"
381
+ )
382
+ return f"No host key found for {self.remote_host} in {known_hosts_path}"
383
+ else:
384
+ self.logger.warning(f"No known_hosts file at {known_hosts_path}")
385
+ return f"No known_hosts file at {known_hosts_path}"
386
+
387
+ def copy_ssh_config(
388
+ self, local_config_path, remote_config_path=os.path.expanduser("~/.ssh/config")
389
+ ):
390
+ """
391
+ Copy a local SSH config to the remote host’s ~/.ssh/config.
392
+ :param local_config_path: Path to the local config file.
393
+ :param remote_config_path: Path on remote (default ~/.ssh/config).
394
+ """
395
+ self.connect()
396
+ self.run_command("mkdir -p ~/.ssh && chmod 700 ~/.ssh")
397
+ self.send_file(local_config_path, remote_config_path)
398
+ self.run_command(f"chmod 600 {remote_config_path}")
399
+ self.logger.info(
400
+ f"Copied SSH config to {remote_config_path} on {self.remote_host}"
401
+ )
402
+
403
+ def rotate_ssh_key(self, new_key_path):
404
+ """
405
+ Rotate the SSH key by generating a new pair and updating authorized_keys.
406
+ :param new_key_path: Path for the new private key.
407
+ """
408
+ new_key_path = os.path.expanduser(new_key_path)
409
+ new_pub_path = new_key_path + ".pub"
410
+ if not os.path.exists(new_key_path):
411
+ os.system(f"ssh-keygen -t rsa -b 4096 -f {new_key_path} -N ''")
412
+ self.logger.info(f"Generated new key pair: {new_key_path}")
413
+
414
+ with open(new_pub_path, "r") as f:
415
+ new_pub = f.read().strip()
416
+
417
+ old_pub = None
418
+ if self.identity_file:
419
+ old_key_path = os.path.expanduser(self.identity_file)
420
+ old_pub_path = old_key_path + ".pub"
421
+ if os.path.exists(old_pub_path):
422
+ with open(old_pub_path, "r") as f:
423
+ old_pub = f.read().strip()
424
+
425
+ self.connect()
426
+ out, err = self.run_command("cat ~/.ssh/authorized_keys")
427
+ auth_keys = out.splitlines()
428
+ new_auth = [
429
+ line
430
+ for line in auth_keys
431
+ if line.strip() and (old_pub is None or line.strip() != old_pub)
432
+ ]
433
+ new_auth.append(new_pub)
434
+
435
+ temp_file = "/tmp/authorized_keys.new"
436
+ # Construct the command string without escape sequences in f-string
437
+ new_auth_joined = "\n".join(new_auth)
438
+ self.run_command(f"echo '{new_auth_joined}' > {temp_file}")
439
+ self.run_command(f"mv {temp_file} ~/.ssh/authorized_keys")
440
+ self.run_command("chmod 600 ~/.ssh/authorized_keys")
441
+
442
+ self.identity_file = new_key_path
443
+ self.password = None
444
+ self.logger.info(f"Rotated key to {new_key_path} on {self.remote_host}")
445
+ logging.info(
446
+ f"Please update SSH config for {self.remote_host} IdentityFile to {new_key_path}"
447
+ )
448
+
449
+ @staticmethod
450
+ def setup_all_passwordless_ssh(
451
+ inventory,
452
+ shared_key_path=os.path.expanduser("~/.ssh/id_shared"),
453
+ group="all",
454
+ parallel=False,
455
+ max_threads=5,
456
+ ):
457
+ """
458
+ Set up passwordless SSH for all hosts in the specified group of the YAML inventory.
459
+ :param inventory: Path to the YAML inventory file.
460
+ :param shared_key_path: Path to a shared private key (optional, generates if missing).
461
+ :param group: Inventory group to target (default: 'all').
462
+ :param parallel: Run in parallel.
463
+ :param max_threads: Max threads for parallel.
464
+ """
465
+ shared_key_path = os.path.expanduser(shared_key_path)
466
+ shared_pub_key_path = shared_key_path + ".pub"
467
+ if not os.path.exists(shared_key_path):
468
+ os.system(f"ssh-keygen -t rsa -b 4096 -f {shared_key_path} -N ''")
469
+ logging.info(
470
+ f"Generated shared key pair: {shared_key_path}, {shared_pub_key_path}"
471
+ )
472
+
473
+ with open(shared_pub_key_path, "r") as f:
474
+ shared_pub_key = f.read().strip()
475
+
476
+ def setup_host(host):
477
+ hostname = host["hostname"]
478
+ username = host["username"]
479
+ password = host["password"]
480
+ key_path = host.get("key_path", shared_key_path)
481
+
482
+ logging.info(f"\nSetting up {username}@{hostname}...")
483
+
484
+ tunnel = Tunnel(
485
+ remote_host=hostname,
486
+ username=username,
487
+ password=password,
488
+ )
489
+ tunnel.remove_host_key()
490
+ tunnel.setup_passwordless_ssh(local_key_path=key_path)
491
+
492
+ try:
493
+ tunnel.connect()
494
+ tunnel.run_command(f"echo '{shared_pub_key}' >> ~/.ssh/authorized_keys")
495
+ tunnel.run_command("chmod 600 ~/.ssh/authorized_keys")
496
+ logging.info(f"Added shared key to {username}@{hostname}")
497
+ except Exception as e:
498
+ logging.error(
499
+ f"Failed to add shared key to {username}@{hostname}: {str(e)}"
500
+ )
501
+ finally:
502
+ tunnel.close()
503
+
504
+ result, msg = tunnel.test_key_auth(shared_key_path)
505
+ logging.info(f"Key auth test for {username}@{hostname}: {msg}")
506
+
507
+ Tunnel.execute_on_all(inventory, setup_host, group, parallel, max_threads)
508
+
509
+ @staticmethod
510
+ def run_command_on_all(
511
+ inventory, command, group="all", parallel=False, max_threads=5
512
+ ):
513
+ """
514
+ Run a shell command on all hosts in the specified group of the YAML inventory.
515
+ :param inventory: Path to the YAML inventory file.
516
+ :param command: The shell command to run.
517
+ :param group: Inventory group to target (default: 'all').
518
+ :param parallel: Run in parallel.
519
+ :param max_threads: Max threads for parallel.
520
+ """
521
+ logger = logging.getLogger("Tunnel")
522
+ logger.info(f"Running command '{command}' on group '{group}'")
523
+ print(f"Executing command '{command}' on group '{group}'...")
524
+
525
+ def run_host(host):
526
+ try:
527
+ tunnel = Tunnel(
528
+ remote_host=host["hostname"],
529
+ username=host["username"],
530
+ password=host.get("password"),
531
+ identity_file=host.get("key_path"),
532
+ )
533
+ out, err = tunnel.run_command(command)
534
+ logger.info(
535
+ f"Host {host['hostname']}: In: {command}, Out: {out}, Err: {err}"
536
+ )
537
+ print(
538
+ f"Host {host['hostname']}:\nInput: {command}\nOutput: {out}\nError: {err}"
539
+ )
540
+ tunnel.close()
541
+ except Exception as e:
542
+ logger.error(f"Failed to run command on {host['hostname']}: {str(e)}")
543
+ print(f"Error on {host['hostname']}: {str(e)}", file=sys.stderr)
544
+
545
+ try:
546
+ Tunnel.execute_on_all(inventory, run_host, group, parallel, max_threads)
547
+ print(f"Completed command execution on group '{group}'")
548
+ except Exception as e:
549
+ logger.error(f"Failed to execute command on group '{group}': {str(e)}")
550
+ print(
551
+ f"Error executing command on group '{group}': {str(e)}", file=sys.stderr
552
+ )
553
+ raise
554
+
555
+ @staticmethod
556
+ def copy_ssh_config_on_all(
557
+ inventory,
558
+ local_config_path,
559
+ remote_config_path=os.path.expanduser("~/.ssh/config"),
560
+ group="all",
561
+ parallel=False,
562
+ max_threads=5,
563
+ ):
564
+ """
565
+ Copy local SSH config to all hosts in the specified group of the YAML inventory.
566
+ :param inventory: Path to the YAML inventory file.
567
+ :param local_config_path: Local SSH config path.
568
+ :param remote_config_path: Remote path (default ~/.ssh/config).
569
+ :param group: Inventory group to target (default: 'all').
570
+ :param parallel: Run in parallel.
571
+ :param max_threads: Max threads for parallel.
572
+ """
573
+
574
+ def copy_host(host):
575
+ tunnel = Tunnel(
576
+ remote_host=host["hostname"],
577
+ username=host["username"],
578
+ password=host.get("password"),
579
+ identity_file=host.get("key_path"),
580
+ )
581
+ tunnel.copy_ssh_config(local_config_path, remote_config_path)
582
+ tunnel.close()
583
+
584
+ Tunnel.execute_on_all(inventory, copy_host, group, parallel, max_threads)
585
+
586
+ @staticmethod
587
+ def rotate_ssh_key_on_all(
588
+ inventory,
589
+ key_prefix=os.path.expanduser("~/.ssh/id_"),
590
+ group="all",
591
+ parallel=False,
592
+ max_threads=5,
593
+ ):
594
+ """
595
+ Rotate SSH keys for all hosts in the specified group of the YAML inventory.
596
+ :param inventory: Path to the YAML inventory file.
597
+ :param key_prefix: Prefix for new key paths (appends hostname).
598
+ :param group: Inventory group to target (default: 'all').
599
+ :param parallel: Run in parallel.
600
+ :param max_threads: Max threads for parallel.
601
+ """
602
+
603
+ def rotate_host(host):
604
+ new_key_path = os.path.expanduser(key_prefix + host["hostname"])
605
+ tunnel = Tunnel(
606
+ remote_host=host["hostname"],
607
+ username=host["username"],
608
+ password=host.get("password"),
609
+ identity_file=host.get("key_path"),
610
+ )
611
+ tunnel.rotate_ssh_key(new_key_path)
612
+ logging.info(
613
+ f"Rotated key for {host['hostname']}. Update inventory key_path to {new_key_path} if needed."
614
+ )
615
+ tunnel.close()
616
+
617
+ Tunnel.execute_on_all(inventory, rotate_host, group, parallel, max_threads)
618
+
619
+ @staticmethod
620
+ def send_file_on_all(
621
+ inventory,
622
+ local_path,
623
+ remote_path,
624
+ group="all",
625
+ parallel=False,
626
+ max_threads=5,
627
+ ):
628
+ """
629
+ Upload a file to all hosts in the specified group of the YAML inventory.
630
+ :param inventory: Path to the YAML inventory file.
631
+ :param local_path: Path to the local file to upload.
632
+ :param remote_path: Path on the remote hosts to save the file.
633
+ :param group: Inventory group to target (default: 'all').
634
+ :param parallel: Run in parallel.
635
+ :param max_threads: Max threads for parallel execution.
636
+ """
637
+
638
+ def send_host(host):
639
+ tunnel = Tunnel(
640
+ remote_host=host["hostname"],
641
+ username=host["username"],
642
+ password=host.get("password"),
643
+ identity_file=host.get("key_path"),
644
+ )
645
+ tunnel.send_file(local_path, remote_path)
646
+ logging.info(f"Host {host['hostname']}: File uploaded to {remote_path}")
647
+ tunnel.close()
648
+
649
+ if not os.path.exists(local_path):
650
+ raise ValueError(f"Local file does not exist: {local_path}")
651
+
652
+ Tunnel.execute_on_all(inventory, send_host, group, parallel, max_threads)
653
+
654
+ @staticmethod
655
+ def receive_file_on_all(
656
+ inventory,
657
+ remote_path: str,
658
+ local_path_prefix,
659
+ group="all",
660
+ parallel=False,
661
+ max_threads=5,
662
+ ):
663
+ """
664
+ Download a file from all hosts in the specified group of the YAML inventory.
665
+ :param inventory: Path to the YAML inventory file.
666
+ :param remote_path: Path on the remote hosts to download the file from.
667
+ :param local_path_prefix: Local directory path prefix to save files (creates host-specific subdirectories).
668
+ :param group: Inventory group to target (default: 'all').
669
+ :param parallel: Run in parallel.
670
+ :param max_threads: Max threads for parallel execution.
671
+ """
672
+
673
+ def receive_host(host):
674
+ host_dir = os.path.join(local_path_prefix, host["hostname"])
675
+ os.makedirs(host_dir, exist_ok=True)
676
+ local_path = os.path.join(f"{host_dir}", os.path.basename(remote_path))
677
+ tunnel = Tunnel(
678
+ remote_host=host["hostname"],
679
+ username=host["username"],
680
+ password=host.get("password"),
681
+ identity_file=host.get("key_path"),
682
+ )
683
+ tunnel.receive_file(remote_path, local_path)
684
+ logging.info(f"Host {host['hostname']}: File downloaded to {local_path}")
685
+ tunnel.close()
686
+
687
+ os.makedirs(local_path_prefix, exist_ok=True)
688
+ Tunnel.execute_on_all(inventory, receive_host, group, parallel, max_threads)
689
+
690
+
691
def tunnel_manager():
    """
    Command-line entry point for Tunnel Manager.

    Parses a sub-command and its options, then configures logging at DEBUG
    level: to ``--log-file`` when given, otherwise to the console. It then
    dispatches to the matching static method on ``Tunnel``. Exits with status 1
    when the log file cannot be set up or the selected operation raises.

    Options a sub-command cannot work without are required, so argparse
    reports them up front instead of the run failing later on a ``None``
    value. These are ``--inventory``, ``--remote-command``, ``--local-path``,
    ``--remote-path`` and ``--local-path-prefix``.
    """
    parser = argparse.ArgumentParser(description="Tunnel Manager CLI")
    parser.add_argument("--log-file", help="Log to this file (default: console output)")

    subparsers = parser.add_subparsers(dest="command", required=True)

    def add_inventory_arg(sub):
        # Every sub-command operates on an inventory file.
        sub.add_argument("--inventory", required=True, help="YAML inventory path")

    def add_execution_args(sub):
        # Group selection and concurrency options shared by all sub-commands.
        sub.add_argument(
            "--group", default="all", help="Inventory group to target (default: all)"
        )
        sub.add_argument("--parallel", action="store_true", help="Run in parallel")
        sub.add_argument(
            "--max-threads", type=int, default=5, help="Max threads for parallel execution"
        )

    # Setup-all command
    setup_parser = subparsers.add_parser("setup-all", help="Setup passwordless for all")
    add_inventory_arg(setup_parser)
    setup_parser.add_argument(
        "--shared-key-path",
        default="~/.ssh/id_shared",
        help="Path to shared private key",
    )
    add_execution_args(setup_parser)

    # Run-command command
    run_parser = subparsers.add_parser("run-command", help="Run command on all")
    add_inventory_arg(run_parser)
    run_parser.add_argument("--remote-command", required=True, help="Shell command to run")
    add_execution_args(run_parser)

    # Copy-config command
    copy_parser = subparsers.add_parser("copy-config", help="Copy SSH config to all")
    add_inventory_arg(copy_parser)
    copy_parser.add_argument(
        "--local-config-path", default="~/.ssh/config", help="Local SSH config path"
    )
    copy_parser.add_argument(
        "--remote-config-path",
        default="~/.ssh/config",
        help="Remote path (default ~/.ssh/config)",
    )
    add_execution_args(copy_parser)

    # Rotate-key command
    rotate_parser = subparsers.add_parser("rotate-key", help="Rotate keys for all")
    add_inventory_arg(rotate_parser)
    rotate_parser.add_argument(
        "--key-prefix",
        default="~/.ssh/id_",
        help="Prefix for new key paths (appends hostname)",
    )
    add_execution_args(rotate_parser)

    # Send-file command
    send_parser = subparsers.add_parser(
        "send-file", help="Upload file to all hosts in inventory"
    )
    add_inventory_arg(send_parser)
    send_parser.add_argument("--local-path", required=True, help="Local file path to upload")
    send_parser.add_argument("--remote-path", required=True, help="Remote destination path")
    add_execution_args(send_parser)

    # Receive-file command
    receive_parser = subparsers.add_parser(
        "receive-file", help="Download file from all hosts in inventory"
    )
    add_inventory_arg(receive_parser)
    receive_parser.add_argument(
        "--remote-path", required=True, help="Remote file path to download"
    )
    receive_parser.add_argument(
        "--local-path-prefix",
        required=True,
        help="Local directory path prefix to save files",
    )
    add_execution_args(receive_parser)

    args = parser.parse_args()

    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    if args.log_file:
        try:
            # Create the log directory first; this can fail too, so it belongs
            # inside the same error handling as opening the file.
            os.makedirs(os.path.dirname(os.path.abspath(args.log_file)), exist_ok=True)
            logging.basicConfig(
                filename=args.log_file,
                level=logging.DEBUG,
                format=log_format,
            )
        except OSError as e:
            print(
                f"Error: Cannot write to log file '{args.log_file}': {str(e)}",
                file=sys.stderr,
            )
            sys.exit(1)
    else:
        logging.basicConfig(level=logging.DEBUG, format=log_format)

    logger = logging.getLogger("Tunnel")
    logger.debug(
        f"Starting Tunnel Automation with command: {args.command}, args: {vars(args)}"
    )
    print(f"Starting Tunnel Automation with command: {args.command}")

    common = (args.group, args.parallel, args.max_threads)
    dispatch = {
        "setup-all": lambda: Tunnel.setup_all_passwordless_ssh(
            args.inventory, args.shared_key_path, *common
        ),
        "run-command": lambda: Tunnel.run_command_on_all(
            args.inventory, args.remote_command, *common
        ),
        "copy-config": lambda: Tunnel.copy_ssh_config_on_all(
            args.inventory, args.local_config_path, args.remote_config_path, *common
        ),
        "rotate-key": lambda: Tunnel.rotate_ssh_key_on_all(
            args.inventory, args.key_prefix, *common
        ),
        "send-file": lambda: Tunnel.send_file_on_all(
            args.inventory, args.local_path, args.remote_path, *common
        ),
        "receive-file": lambda: Tunnel.receive_file_on_all(
            args.inventory, args.remote_path, args.local_path_prefix, *common
        ),
    }

    try:
        dispatch[args.command]()
        logger.debug("Automation Complete")
        print("Automation Complete")
    except Exception as e:
        logger.error(f"Automation failed: {str(e)}")
        print(f"Error: Automation failed: {str(e)}", file=sys.stderr)
        sys.exit(1)
888
+
182
889
 
183
- # Example usage (commented out):
184
- # tunnel = Tunnel("your-remote-host.example.com", log_file="tunnel.log")
185
- # tunnel.connect()
186
- # out, err = tunnel.run_command("ls -la")
187
- # print(out)
188
- # tunnel.send_file("/local/file.txt", "/remote/file.txt")
189
- # tunnel.receive_file("/remote/file.txt", "/local/downloaded.txt")
190
- # tunnel.close()
890
if __name__ == "__main__":
    # Executed as a script (e.g. `python tunnel_manager.py ...`), not imported:
    # hand control to the CLI.
    tunnel_manager()