p3lib 1.1.108__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
p3lib/ssh.py ADDED
@@ -0,0 +1,935 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import socket
5
+ from paramiko import SSHClient, AutoAddPolicy, AuthenticationException, SFTPClient
6
+ import logging
7
+ import threading
8
+ import socketserver
9
+ import select
10
+
11
+ from getpass import getuser, getpass
12
+
13
class SSHError(Exception):
    """@brief Exception raised by this module when an ssh related operation fails."""
15
+
16
+ # -------------------------------------------------------------------------------
17
+
18
class ExtendedSSHClient(SSHClient):
    """@brief A paramiko SSHClient that automatically accepts unknown host keys and
              adds helpers to start remote commands and to run them to completion."""

    # Seconds to sleep between polls of a command channel while waiting for output/exit.
    CMD_POLL_SECONDS = 0.1
    # Maximum number of bytes read from a command channel stream in one recv call.
    RX_CHUNK_SIZE = 32768

    @staticmethod
    def GetLines(bytes):
        """@brief Split the text into lines.
           @param bytes The bytes containing UTF-8 text to be split into lines of text.
                  Invalid UTF-8 sequences are replaced (U+FFFD) rather than raising an error,
                  because remote command output is not guaranteed to be valid UTF-8.
           @return A List of lines of text with CR/LF removed. Empty input gives an empty
                   list; text that ends with a newline gives a trailing empty string."""
        text = bytes.decode('utf-8', errors='replace')
        if not text:
            return []
        return ExtendedSSHClient.StripEOL(text.split("\n"))

    @staticmethod
    def StripEOL(lines):
        """@brief Strip the end of line characters from the list of lines of text.
           @param lines A list of lines of text.
           @return A new list holding the same lines with trailing LF then CR removed."""
        return [line.rstrip("\n").rstrip("\r") for line in lines]

    def __init__(self):
        super(ExtendedSSHClient, self).__init__()
        # NOTE(review): AutoAddPolicy accepts any (unknown) host key without verification.
        self.set_missing_host_key_policy(AutoAddPolicy())

    def __exec_command(self, command, bufsize=-1, timeout=None):
        """
        Execute a command on the SSH server and wait for it to complete. A new
        L{Channel} is opened and the requested command is executed.

        @param command: the command to execute
        @type command: str
        @param bufsize: interpreted the same way as by the built-in C{file()} function in python
                        (used for the returned stdin file)
        @type bufsize: int
        @param timeout: Seconds allowed for opening the session and for individual channel
                        operations, or None for no limit. It does not limit the total run
                        time of the command.
        @type timeout: None or float
        @return: (stdin file, stdout file, stderr file, exit status). stdout and stderr are
                 in-memory file objects holding all the output the command produced.

        @raise SSHException: if the server fails to execute the command
        """
        import io

        chan = self._transport.open_session(timeout=timeout)
        chan.settimeout(timeout)
        chan.exec_command(command)
        stdin = chan.makefile('wb', bufsize)
        # Read all output BEFORE waiting for the exit status. Calling recv_exit_status()
        # first can hang forever once the output fills the channel window.
        stdoutData, stderrData = self._drainChannel(chan)
        exitStatus = chan.recv_exit_status()
        return stdin, io.BytesIO(stdoutData), io.BytesIO(stderrData), exitStatus

    def _drainChannel(self, chan):
        """@brief Read all stdout and stderr data from a command channel.
           Both streams are read in the same polling loop so neither can fill the channel
           window and stall the remote command.
           @param chan The paramiko channel the command is executing in.
           @return A tuple (stdout bytes, stderr bytes)."""
        import time

        stdoutChunks = []
        stderrChunks = []
        while True:
            # Sample EOF/close *before* draining. After EOF the server sends no further
            # stdout or stderr data, so the drain below then collects everything left.
            outputComplete = bool(chan.eof_received or chan.closed)
            gotData = False
            while chan.recv_ready():
                stdoutChunks.append(chan.recv(self.RX_CHUNK_SIZE))
                gotData = True
            while chan.recv_stderr_ready():
                stderrChunks.append(chan.recv_stderr(self.RX_CHUNK_SIZE))
                gotData = True
            if outputComplete and chan.exit_status_ready():
                break
            if not gotData:
                time.sleep(self.CMD_POLL_SECONDS)
        return b"".join(stdoutChunks), b"".join(stderrChunks)

    def startCmd(self, command ):
        """
        Start executing a command on the SSH server. A new L{Channel} is opened and
        the requested command is executed. This returns without waiting for the command.

        @param command: The command to execute.
        @return The SSH channel session with the command executing in it (it is up to the caller to read data from this
                channel in a timely manner).
        @raise SSHException: if the server fails to execute the command
        """
        chan = self._transport.open_session()
        chan.exec_command(command)
        return chan

    def runCmd(self, cmd, throwError=True, timeout=None):
        """@brief Run a command over an ssh session
           @param cmd The command to execute
           @param throwError If True and the command fails throw an error. If False and the command fails return the error code etc.
                             If throwError is True and the exit status of the command is not 0 then an SSHError will be thrown
                             (its message is the command's stderr text, or the exit code if stderr was empty).
           @param timeout Seconds allowed for opening the session and for channel operations, or None (default) for no limit.
           @return A list with the following three elements if the command executes
                   0 - the return code/exit status of the command
                   1 - Lines of text from stdout
                   2 - lines of text from stderr
                   If the command fails to execute and throwError is False then None is returned.
        """
        try:
            _stdin, stdout, stderr, exitStatus = self.__exec_command(cmd, timeout=timeout)
            stdoutBytes = stdout.read()
            stderrBytes = stderr.read()
            if throwError and exitStatus != 0:
                # Decode so the exception message is readable text rather than a bytes repr.
                errorText = stderrBytes.decode('utf-8', errors='replace').rstrip()
                if errorText:
                    raise SSHError(errorText)
                raise SSHError("The cmd '%s' return the error code: %d" % (cmd, exitStatus))
            return [exitStatus, ExtendedSSHClient.GetLines(stdoutBytes), ExtendedSSHClient.GetLines(stderrBytes)]
        except Exception:
            # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit are never swallowed.
            if throwError:
                raise
            return None
116
+
117
+
118
+
119
class SSH(object):
    """@brief Responsible for connecting an ssh connection, executing commands,
              transferring files over sftp and maintaining the server's authorised
              keys file so that key based (passwordless) logins work."""

    # Private key file names (in the local ssh folder) tried, in order, when connecting.
    PRIVATE_KEY_FILE_LIST = ["id_rsa", "id_dsa", 'id_ecdsa', 'id_ed25519']
    # Public key file names (in the local ssh folder), searched in order.
    PUBLIC_KEY_FILE_LIST = ["id_rsa.pub", "id_dsa.pub", 'id_ecdsa.pub', 'id_ed25519.pub']
    LOCAL_SSH_CONFIG_PATH = os.path.join(os.path.expanduser("~"), ".ssh")
    DEFAULT_REMOTE_SSH_CONFIG_FOLDER = "~/.ssh"
    DEFAULT_REMOTE_SSH_AUTH_KEYS_FILE = "~/.ssh/authorized_keys"
    # If this folder exists on the server, its authorised keys file (below) is used instead.
    DROPBEAR_DIR = "/etc/dropbear"
    DROPBEAR_AUTH_KEYS_FILE = "%s/authorized_keys" % (DROPBEAR_DIR)
    SSH_COPY_PROG = "/usr/bin/ssh-copy-id"
    SERVER_AUTHORISED_KEYS_FILE = "~/.ssh/authorized_keys"
    DEFAULT_SSH_CONNECTION_TIMEOUT = 20

    @staticmethod
    def AddKey(privateKey):
        """@brief Add a key to the list of keys that will be tried when connecting to an ssh server.
                  This method should be called before instantiating the SSH class. If a key is
                  already present it won't be added to avoid duplicates.
           @param privateKey The name of the private key file in the ~/.ssh folder. The public
                             key file must have the same name with a .pub extension."""
        publicKey = privateKey + ".pub"
        if privateKey not in SSH.PRIVATE_KEY_FILE_LIST:
            SSH.PRIVATE_KEY_FILE_LIST.insert(0, privateKey)
        # Check the .pub name (not the private key name) so repeated calls add no duplicates.
        if publicKey not in SSH.PUBLIC_KEY_FILE_LIST:
            SSH.PUBLIC_KEY_FILE_LIST.insert(0, publicKey)

    @staticmethod
    def _GetLocalSSHFolder():
        """@brief Get the local ssh folder (<HOME>/.ssh).
                  If that folder does not exist, fall back to /root/.ssh for root or
                  /home/<user>/.ssh for other users.
           @return The folder path (it is not guaranteed to exist)."""
        sshFolder = SSH.LOCAL_SSH_CONFIG_PATH
        if not os.path.isdir(sshFolder):
            username = getuser()
            if username == 'root':
                sshFolder = '/root/.ssh'
            else:
                sshFolder = '/home/%s/.ssh' % (username)
        return sshFolder

    @staticmethod
    def _GetExistingKeyFiles(keyNames):
        """@brief Get the key files, from the given names, that exist in the local ssh folder.
           @param keyNames An ordered list of key file names.
           @return A list of existing key file paths in the same order, without duplicates."""
        sshFolder = SSH._GetLocalSSHFolder()
        keyFileList = []
        for keyName in keyNames:
            keyFile = os.path.join(sshFolder, keyName)
            if os.path.isfile(keyFile) and keyFile not in keyFileList:
                keyFileList.append(keyFile)
        return keyFileList

    @staticmethod
    def GetPublicKeyFile():
        """@brief Get the first public key file (in PUBLIC_KEY_FILE_LIST order) from <HOME>/.ssh
           @return The public key file path.
           @raise SSHError If no public key file exists."""
        keyFileList = SSH._GetExistingKeyFiles(SSH.PUBLIC_KEY_FILE_LIST)
        if keyFileList:
            return keyFileList[0]
        raise SSHError("Unable to find a public key file. Please use the 'ssh-keygen -t rsa' command to generate a key pair.")

    @staticmethod
    def GetPublicKeyFileList():
        """@brief Get a list of public key files that exist locally in <HOME>/.ssh"""
        return SSH._GetExistingKeyFiles(SSH.PUBLIC_KEY_FILE_LIST)

    @staticmethod
    def GetPrivateKeyFile():
        """@brief Get the first private key file (in PRIVATE_KEY_FILE_LIST order) from <HOME>/.ssh
           @return The private key file path.
           @raise SSHError If no private key file exists."""
        keyFileList = SSH._GetExistingKeyFiles(SSH.PRIVATE_KEY_FILE_LIST)
        if keyFileList:
            return keyFileList[0]
        raise SSHError("Unable to find a private key file. Please use the 'ssh-keygen -t rsa' command to generate a key pair.")

    @staticmethod
    def GetPrivateKeyFileList():
        """@brief Get a list of private key files that exist locally in <HOME>/.ssh"""
        return SSH._GetExistingKeyFiles(SSH.PRIVATE_KEY_FILE_LIST)

    @staticmethod
    def GetPublicKey():
        """@brief Get the public ssh key from the local machine
           @return The public key (first line of the public key file, whitespace stripped).
           @raise SSHError If no public key file exists or it is empty."""
        pubKeyFile = SSH.GetPublicKeyFile()
        with open(pubKeyFile, 'r') as fd:
            lines = fd.readlines()

        if not lines:
            raise SSHError("No public key text found in the %s file on the local computer." % (pubKeyFile))

        return lines[0].strip()

    @staticmethod
    def GetSSHKeyAttributes(authKey):
        """@brief Extract the following from an ssh key and return them in a tuple.
           @param authKey A key line of the form '<keytype> <key> <user>[@<host>]'.
           @return a tuple containing

                   hostname ('?' if the ID field has no '@')
                   username
                   keytype
                   key

                   If the key has fewer than three fields, (None, None, None, None) is returned.
        """
        elems = authKey.split()
        if len(elems) < 3:
            return (None, None, None, None)

        keytype, key, keyID = elems[0], elems[1], elems[2]
        idElems = keyID.split("@")
        if len(idElems) > 1:
            username, hostname = idElems[0], idElems[1]
        else:
            username, hostname = keyID, "?"
        return (hostname, username, keytype, key)

    def __init__(self, host, username, password=None, useCompression=True, port=22, uio=None, privateKeyFile = None):
        """@brief Constructor
           @param host The SSH hostname
           @param username The ssh username
           @param password The ssh password (default=None)
           @param useCompression If True then use compression on the ssh session (default=True)
           @param port The ssh port number (default = 22)
           @param uio A UIO instance (default=None)
           @param privateKeyFile The private ssh keyfile (default=None=Use default private keyfile)
        """
        self._host = host
        self._port = port
        self._username = username
        self._localAddress = None
        self.useCompression = useCompression
        self._password = password
        self._uio = uio
        if privateKeyFile:
            SSH.AddKey(privateKeyFile)

        self._ssh = ExtendedSSHClient()
        logging.getLogger("paramiko").setLevel(logging.WARNING)
        # NOTE(review): AutoAddPolicy accepts unknown host keys without verification.
        self._ssh.set_missing_host_key_policy(AutoAddPolicy())
        self._sftp = None

    def _info(self, text):
        """@brief Present an info level message to the user.
           @param text The text to be presented to the user."""
        if self._uio:
            self._uio.info(text)

    def _warn(self, text):
        """@brief Present an warning level message to the user.
           @param text The text to be presented to the user."""
        if self._uio:
            self._uio.warn(text)

    def _debug(self, text):
        """@brief Present debug level message to the user.
           @param text The text to be presented to the user."""
        if self._uio:
            self._uio.debug(text)

    def isConnected(self):
        """@brief Check if a connection is active.
           @return True if connected."""
        transport = self._ssh.get_transport() if self._ssh else None
        if transport is None:
            return False
        return transport.is_active()

    def _connect(self, connectSFTPSession=False, timeout=DEFAULT_SSH_CONNECTION_TIMEOUT):
        """@brief Connect the ssh connection, trying each local private key in turn.
           @param connectSFTPSession If True then just after the ssh connection
                  is built an SFTP session will be built ready for file transfer.
           @param timeout The connection timeout in seconds.
           @return a ref to the SSHClient object
           @raise AuthenticationException If the server rejects authentication.
           @raise SSHError If no connection could be built with any key."""
        if not self._ssh:
            self._ssh = ExtendedSSHClient()

        connected = False
        lastError = None
        for privateKeyFile in SSH.GetPrivateKeyFileList():
            self._debug("Trying private key {}".format(privateKeyFile))
            cfg = {
                'hostname': self._host,
                'port': self._port,
                'timeout': timeout,
                'username': self._username,
                'key_filename': privateKeyFile,
                # Older paramiko versions needed
                # 'disabled_algorithms': dict(pubkeys=['rsa-sha2-256', 'rsa-sha2-512'])
                # to log in without a password. This is not needed as of paramiko on 8 Mar 2024.
            }
            if self._password:
                cfg['password'] = self._password
            try:
                self._ssh.connect(**cfg)
                connected = True
                break

            # Propagate authentication failures: connect() uses them to trigger copying the
            # local public ssh key to the server so that future logins are passwordless.
            except AuthenticationException:
                raise

            except Exception as ex:
                # Try the next key, but keep the reason in case every key fails.
                lastError = ex
                self._debug("Failed to connect using {}: {}".format(privateKeyFile, ex))

        if not connected:
            raise SSHError("Failed to connect to the SSH server {}@{}.".format(self._username, self._host)) from lastError

        # It can be useful to know what local IP address was used to reach the ssh server
        transport = self._ssh.get_transport()
        self._localAddress = transport.sock.getsockname()[0]
        transport.use_compression(self.useCompression)
        if connectSFTPSession:
            self._sftp = SFTPClient.from_transport(transport)
        return self._ssh

    def getLocalAddress(self):
        """@brief Get the local IP address of the network interface used to connect to the ssh server"""
        return self._localAddress

    def getSSHClient(self):
        """@brief return a ref to the SSHClient object"""
        return self._ssh

    def close(self):
        """@brief Close an open ssh connection (and its sftp session if one was opened)."""
        if self._sftp:
            self._sftp.close()
            self._sftp = None

        if self._ssh:
            self._ssh.close()
            self._ssh = None

    def getTransport(self):
        """@brief Get the ssh transport object. Should only be
                  called when the ssh session is connected.
           @return The ssh transport object."""
        return self._ssh.get_transport()

    def runCmd(self, cmd, throwError=True, timeout=None):
        """@brief Run a command over an ssh session
           @param cmd The command to execute
           @param throwError If True and the command fails throw an error. If False and the command fails return the error code etc.
                             If throwError is True and the exit status of the command is not 0 then an SSHError will be thrown
           @param timeout The timeout value in seconds or None (default) if no timeout is required.
           @return A list with the following three elements
                   0 - the return code/exit status of the command
                   1 - Lines of text from stdout
                   2 - lines of text from stderr
                   (None if throwError is False and the command could not be executed.)
        """
        return self._ssh.runCmd(cmd, throwError=throwError, timeout=timeout)

    def _runCmdNoThrow(self, cmd):
        """@brief Run a command, returning its result even if it exits with a non zero status.
           @param cmd The command to execute.
           @return A list [exit status, stdout lines, stderr lines].
           @raise SSHError If the command could not be executed at all."""
        result = self.runCmd(cmd, throwError=False)
        if result is None:
            raise SSHError("Failed to execute the cmd '%s' on the ssh server." % (cmd))
        return result

    def startCmd(self, cmd):
        """@brief Start executing a command. This will return after starting the command and before the command has completed.
                  The following methods maybe called to interrogate the command executions
                  channel.exit_status_ready()
                  channel.recv_ready()
                  channel.recv(len(channel.in_buffer))
                  channel.recv_stderr_ready()
                  channel.recv_stderr(len(channel.in_stderr_buffer))

                  When the command is complete the caller should call channel.close()
           @return A channel instance in which the command is executing.
        """
        return self._ssh.startCmd(cmd)

    def connect(self, enableAutoLoginSetup=False, connectSFTPSession=False, timeout=DEFAULT_SSH_CONNECTION_TIMEOUT):
        """@brief Connect the ssh connection. If authentication fails the user is prompted
                  (via the UIO instance) for the password, the local ssh public key is
                  copied to the server and the connection is attempted again.
           @param enableAutoLoginSetup Intended to select whether auto login setup is attempted.
                  NOTE(review): this argument is not read; auto login setup is attempted on
                  any authentication failure regardless of its value.
           @param connectSFTPSession If True then just after the ssh connection
                  is built an SFTP session will be built ready for file transfer.
           @param timeout The connection timeout in seconds.
           @return False if connected without setting up auto login. True if auto
                   login had to be set up for the connection to succeed."""
        setupAutoLogin = False
        try:
            self._connect(connectSFTPSession=connectSFTPSession, timeout=timeout)

        except AuthenticationException:
            self._setupAutologin(timeout)
            self._connect(connectSFTPSession=connectSFTPSession, timeout=timeout)
            setupAutoLogin = True
        return setupAutoLogin

    def _setupAutologin(self, timeout=DEFAULT_SSH_CONNECTION_TIMEOUT):
        """@brief Setup autologin on the ssh server.
                  The user is prompted for the ssh password, a password connection is made
                  and the local public key is added to the server authorised keys file.
           @param timeout The connection timeout in seconds.
           @raise SSHError If no UIO instance is available to read the password."""
        if self._uio is None:
            raise SSHError("Unable to set up ssh auto login as no UIO instance is available to read the ssh password.")

        self._warn("Auto login to the ssh server failed authentication.")
        self._info("Copying the local public ssh key to the ssh server for automatic login.")
        self._info("Please enter the ssh server ({}) password for the user: {}".format(self._host, self._username))

        self._password = self._uio.getPassword("SSH password: ")

        self._connect(timeout=timeout)
        try:
            self._ensureAutoLogin()
        finally:
            # Always drop this temporary password connection, even if the key copy failed.
            self.close()

        self._info("Local public ssh key copied to the ssh server.")

    def _ensureAutoLogin(self):
        """@brief Ensure that ssh auto login is enabled by adding the local public key
                  to the server authorised keys file if it is not already present."""
        localPublicKey = SSH.GetPublicKey()
        keyHostname, keyUsername, _keytype, _ = SSH.GetSSHKeyAttributes(localPublicKey)
        if keyHostname is None:
            keyHostname = socket.gethostname()
        if keyUsername is None:
            # getuser is imported directly from the getpass module.
            keyUsername = getuser()

        self._info("Using key: %s@%s" % (keyUsername, keyHostname))
        remoteAuthorisedKeys = self.getRemoteAuthorisedKeys()
        # The key is already authorised if any remote line starts with the local public key.
        alreadyAuthorised = any(line.startswith(localPublicKey) for line in remoteAuthorisedKeys)

        if alreadyAuthorised:
            self._info("The server already has the local ssh key (%s) in its authorized_key file." % (SSH.GetPublicKeyFile()))
        else:
            remoteAuthKeysFile = self.updateAuthorisedKeys(localPublicKey)
            self._info("Updated the remote %s file from the local %s file." % (remoteAuthKeysFile, SSH.GetPublicKeyFile()))

    def updateAuthorisedKeys(self, publicKey):
        """@brief Append a public ssh key to the authorised keys file on the remote ssh server.
           @param publicKey The public key text.
           @return The remote authorised keys file path."""
        import shlex

        authKeysFile = self.getRemoteAuthorisedKeyFile()
        if self._runCmdNoThrow("test -d %s" % (authKeysFile))[0] == 0:
            # If this is a dir with nothing in it, delete it and create an empty
            # authorized_keys file.
            self.runCmd("rmdir %s" % (authKeysFile))
            self.runCmd("touch %s" % (authKeysFile))

        # Quote the key so shell metacharacters in its comment cannot alter the command.
        # (The file path is left unquoted so the shell still expands '~'.)
        self.runCmd("echo %s >> %s" % (shlex.quote(publicKey), authKeysFile))
        self.runCmd("chmod 600 %s" % (authKeysFile))
        return authKeysFile

    def getRemoteAuthorisedKeys(self):
        """@brief Get the remote authorised keys over the ssh connection, creating an empty
                  authorised keys file if none exists.
           @return A list of the non empty lines in the remote authorised keys file.
           @raise SSHError If the file does not exist and cannot be created."""
        authKeysFile = self.getRemoteAuthorisedKeyFile()
        if self._runCmdNoThrow("test -f %s" % (authKeysFile))[0] != 0:
            # Auth keys file not found, attempt to create an empty one.
            self._runCmdNoThrow("touch %s" % (authKeysFile))
            if self._runCmdNoThrow("test -f %s" % (authKeysFile))[0] != 0:
                raise SSHError("!!! Server auth keys file not found (%s). Failed to create it." % (authKeysFile))

        _rc, stdoutLines, _stderrLines = self._runCmdNoThrow("cat %s" % (authKeysFile))
        return [line for line in stdoutLines if line.strip()]

    def getRemoteAuthorisedKeyFile(self):
        """@brief Return the remote authorised key file for the current ssh connection.
                  The dropbear file is used if /etc/dropbear exists on the server. Otherwise
                  ~/.ssh/authorized_keys is used, and ~/.ssh is created (mode 700) if missing."""
        if self._runCmdNoThrow("test -d %s" % (SSH.DROPBEAR_DIR))[0] == 0:
            # The ssh server uses dropbear
            return SSH.DROPBEAR_AUTH_KEYS_FILE

        # Check for the users ssh config folder, creating it if it does not exist.
        if self._runCmdNoThrow("test -d %s" % (SSH.DEFAULT_REMOTE_SSH_CONFIG_FOLDER))[0] != 0:
            # 700 so the folder is private to the user, as OpenSSH recommends.
            self.runCmd(f"mkdir -m 700 {SSH.DEFAULT_REMOTE_SSH_CONFIG_FOLDER}", throwError=True)

        return SSH.DEFAULT_REMOTE_SSH_AUTH_KEYS_FILE

    def getFile(self, remoteFilePath, localFilePath ):
        """@brief Get a file from the sftp server
           @param remoteFilePath The remote file on the ssh server.
           @param localFilePath The path of the file after it's been received
           @raise SSHError If no sftp session is connected."""
        if self._sftp:
            self._sftp.get(remoteFilePath,localFilePath)
        else:
            raise SSHError("SFTP not connected.")

    def putFile(self, localFilePath, remoteFilePath ):
        """@brief Send a local file to the sftp server
           @param localFilePath The path of the local file to send.
           @param remoteFilePath The destination path on the ssh server.
           @raise SSHError If no sftp session is connected."""
        if self._sftp:
            self._sftp.put(localFilePath, remoteFilePath)
        else:
            raise SSHError("SFTP not connected.")

    def getAuthKeyBackupFile(self, maxBackupFileCount=10):
        """@brief Get the name of the backup file for the authorised keys file.
           @param maxBackupFileCount The maximum number of backup files to keep.
           @return The remote backup file name to use.
                   - We create up to maxBackupFileCount backup files.
                   - Once all the backup files have been created we always replace the oldest
                     backup file.
                   - The backup files have the suffix .backup1, .backup2 etc.
           @raise SSHError If no backup file name can be determined.
        """
        authKeysFile = self.getRemoteAuthorisedKeyFile()
        backupFilePrefix = "%s.backup" % (authKeysFile)
        # At least .backup1 is always checked.
        for suffixNum in range(1, max(maxBackupFileCount, 1) + 1):
            backupFileName = "%s%d" % (backupFilePrefix, suffixNum)
            if self._runCmdNoThrow("test -f %s" % (backupFileName))[0] != 0:
                return backupFileName

        # All the backup files exist. List them oldest first (ls -ltr) and reuse the
        # oldest so that we roll around, always replacing the oldest backup file.
        rc, stdoutLines, _stderrLines = self._runCmdNoThrow("ls -ltr %s*" % (backupFilePrefix))
        if rc == 0 and stdoutLines:
            elems = stdoutLines[0].split()
            if elems:
                return elems[-1]

        raise SSHError("Unable to backup %s to %s. Please manually remove the backup files on the ssh server." % (
            authKeysFile, backupFilePrefix))

    def _getPublicKeyID(self, publicKey):
        """@brief Get the public key ID string (the comment field after the key data).
           @param publicKey The ssh public key string ('<keytype> <key> <ID>').
           @return The ID. An ID containing spaces is returned with single spaces.
           @raise SSHError If the key has no ID field."""
        elems = publicKey.split()
        if len(elems) >= 3:
            return " ".join(elems[2:])
        raise SSHError("{} is an invalid public key.".format(publicKey) )

    @staticmethod
    def _KeyLineHasID(keyLine, keyID):
        """@brief Determine whether an authorised keys line contains a key ID.
                  Whole whitespace separated words are compared so that, for example,
                  the ID 'pi@host' does not match a line for 'pi@host2'.
           @param keyLine A line from an authorised keys file.
           @param keyID The key ID.
           @return True if the words of keyID appear consecutively in keyLine."""
        idWords = keyID.split()
        lineWords = keyLine.split()
        wordCount = len(idWords)
        if wordCount == 0:
            return False
        return any(lineWords[i:i + wordCount] == idWords for i in range(len(lineWords) - wordCount + 1))

    def removeAuthKey(self, publicKey):
        """@brief Remove a key from the server authorised keys file.
                  Every line containing the ID of publicKey is removed. The previous file
                  is kept under the name returned by getAuthKeyBackupFile().
           @param publicKey The ssh public key string whose ID identifies the keys to remove.
           @return True if one or more keys were removed, False if none were found.
           @raise SSHError If publicKey has no ID or a remote command fails."""
        import shlex

        publicKeyID = self._getPublicKeyID(publicKey)
        previousAuthKeysFile = self.getAuthKeyBackupFile()
        self._info("Public ssh key ID: {}".format(publicKeyID))
        authKeysFile = self.getRemoteAuthorisedKeyFile()
        tmpAuthKeysFile = "%s.tmp" % (authKeysFile)

        remove = False
        retCode, authKeyLines, _ = self._runCmdNoThrow("cat {}".format(authKeysFile))
        if retCode == 0:
            keptLines = []
            for authKeyLine in authKeyLines:
                if SSH._KeyLineHasID(authKeyLine, publicKeyID):
                    self._info("Found {} public key on ssh server.".format(publicKeyID))
                    remove = True
                elif authKeyLine.strip():
                    keptLines.append(authKeyLine)

            if remove:
                # Remove any pre existing tmp auth keys file
                self.runCmd("rm -f %s" % (tmpAuthKeysFile), throwError=False)

                # Create empty tmp auth keys file
                self.runCmd("touch %s" % (tmpAuthKeysFile))

                # Quote each line: key options such as command="..." contain double quotes.
                # Errors are raised so a failed write aborts before the live file is replaced.
                for keptLine in keptLines:
                    self.runCmd("echo %s >> %s" % (shlex.quote(keptLine), tmpAuthKeysFile))
                self.runCmd("chmod 600 %s" % (tmpAuthKeysFile))

                # Remove any pre existing previous auth keys file
                self.runCmd("rm -f %s" % (previousAuthKeysFile), throwError=False)

                # Move the current auth keys file to the old one and the tmp to the current one
                self.runCmd("mv %s %s" % (authKeysFile, previousAuthKeysFile))
                self.runCmd("mv %s %s" % (tmpAuthKeysFile, authKeysFile))
                self._info("Removed {} public key from the ssh server.".format(publicKeyID))

        if not remove:
            self._info("{} public key not found in ssh server authorised keys file.".format(publicKeyID))

        return remove
647
+
648
+ class SSHTunnelManager(object):
649
+ """@brief Responsible for setting up, tearing down and maintaining lists of
650
+ SSH port forwarding and ssh reverse port forwarding connections."""
651
+
652
+ RX_BUFFER_SIZE = 65535
653
+
654
+ def __init__(self, uio, ssh, useCompression):
655
+ """@brief Constructor
656
+ @param uio UIO instance
657
+ @param ssh An instance of SSHClient that has previously been
658
+ connected to an ssh server"""
659
+ self._uio = uio
660
+ self._ssh = ssh
661
+ self._useCompression = useCompression
662
+ if not self._ssh.getTransport().is_active():
663
+ raise SSHError("!!! The ssh connection is not connected !!!")
664
+
665
+ self._forwardingServerList = []
666
+ self._reverseSShDict = {}
667
+
668
+ def _info(self, text):
669
+ """@brief Present an info level message to the user.
670
+ @param text The text to be presented to the user."""
671
+ if self._uio:
672
+ self._uio.info(text)
673
+
674
+ def _warn(self, text):
675
+ """@brief Present an warning level message to the user.
676
+ @param text The text to be presented to the user."""
677
+ if self._uio:
678
+ self._uio.warn(text)
679
+
680
+ def _error(self, text):
681
+ """@brief Present an error level message to the user.
682
+ @param text The text to be presented to the user."""
683
+ if self._uio:
684
+ self._uio.error(text)
685
+
686
def startFwdSSHTunnel(self, serverPort, destHost, destPort, serverBindAddress=''):
    """@brief Start an ssh port forwarding tunnel. This is a non blocking method.
              A daemon thread is started to run the local forwarding server.
       @param serverPort The TCP server port. On a port forwarding connection
              the TCP server runs on the src end of the ssh connection.
              This is the machine that this python code is executing on.
       @param destHost The host address of the tunnel destination at the remote
              end of the ssh connection.
       @param destPort The host TCP port of the tunnel destination at the remote
              end of the ssh connection.
       @param serverBindAddress The local address the server binds to ('' = all interfaces)."""
    self._info("Forwarding local TCP server port (%d) to %s:%d on the remote end of the ssh connection." % (
        serverPort, destHost, destPort))
    sshTransport = self._ssh.getTransport()
    sshTransport.use_compression(self._useCompression)
    handlerUIO = self._uio

    # Handler class carrying this tunnel's destination and transport for ForwardingHandler.
    class SubHander(ForwardingHandler):
        chain_host = destHost
        chain_port = destPort
        ssh_transport = sshTransport
        uo = handlerUIO

    server = ForwardingServer((serverBindAddress, serverPort), SubHander)
    self._forwardingServerList.append(server)
    threading.Thread(target=server.serve_forever, daemon=True).start()
714
+
715
def startFwdTunnel(self, serverBindAddress, serverPort, destHost, destPort):
    """@brief Alternative way to start a forward SSH tunnel that requires the caller to
              give the server bind address. startFwdSSHTunnel() is the original method
              and is kept for compatibility; this simply delegates to it.
       @param serverBindAddress The server address to bind to.
       @param serverPort The TCP server port. On a port forwarding connection
              the TCP server runs on the src end of the ssh connection.
              This is the machine that this python code is executing on.
       @param destHost The host address of the tunnel destination at the remote
              end of the ssh connection.
       @param destPort The host TCP port of the tunnel destination at the remote
              end of the ssh connection."""
    tunnelArgs = (serverPort, destHost, destPort)
    self.startFwdSSHTunnel(*tunnelArgs, serverBindAddress=serverBindAddress)
728
+
729
def stopFwdSSHTunnel(self, serverPort):
    """@brief stop a previously started ssh port forwarding server
       @param serverPort The local TCP server port that is currently accepting
              port forwarding connections. Every server on this port is shut down,
              closed and forgotten."""
    stillRunning = []
    for forwardingServer in self._forwardingServerList:
        if forwardingServer.server_address[1] == serverPort:
            forwardingServer.shutdown()
            forwardingServer.server_close()
            self._info("Shutdown ssh port forwarding connection using local server port %d." % (serverPort))
        else:
            stillRunning.append(forwardingServer)
    # Forget stopped servers so later stop calls do not shut them down and report them again.
    # Updated in place so any other reference to the list sees the change.
    self._forwardingServerList[:] = stillRunning
739
+
740
def stopAllFwdSSHTunnels(self):
    """@brief Stop all previously started ssh port forwarding servers, then forget them."""
    for forwardingServer in self._forwardingServerList:
        forwardingServer.shutdown()
        forwardingServer.server_close()
        self._info("Shutdown ssh port forwarding on %s." % (str(forwardingServer.server_address)))
    # Clear (in place) so a later call does not shut the same servers down again.
    self._forwardingServerList.clear()
746
+
747
def startRevSSHTunnel(self, serverPort, destHost, destPort, serverBindAddress=''):
    """@brief Start an ssh reverse port forwarding tunnel
       @param serverPort The TCP server port. On a reverse port forwarding connection
              the TCP server runs on the dest end of the ssh connection.
              This is the machine at the remote end of the ssh connection.
       @param destHost The host address of the tunnel destination at the local
              end of the ssh connection.
       @param destPort The host TCP port of the tunnel destination at the local
              end of the ssh connection.
       @param serverBindAddress The server address to bind to.
       Any exception raised by the transport's request_port_forward() propagates, and the
       tunnel registration made by this call is undone first."""
    self._info("Forwarding (reverse) Remote TCP server port (%d) to %s:%d on this end of the ssh connection." % (
        serverPort, destHost, destPort))
    # The None channel/socket entries are placeholders that are filled in later
    # (presumably by _startReverseForwardingHandler when a connection arrives).
    hadPrevious = serverPort in self._reverseSShDict
    previous = self._reverseSShDict.get(serverPort)
    self._reverseSShDict[serverPort] = (destHost, destPort, None, None)

    transport = self._ssh.getTransport()
    transport.use_compression(self._useCompression)
    try:
        transport.request_port_forward(serverBindAddress, serverPort, handler=self._startReverseForwardingHandler)
    except Exception:
        # Do not leave a stale entry for a tunnel the server refused.
        if hadPrevious:
            self._reverseSShDict[serverPort] = previous
        else:
            del self._reverseSShDict[serverPort]
        raise
766
+
767
def startRevTunnel(self, serverBindAddress, serverPort, destHost, destPort):
    """@brief Start a reverse SSH tunnel, requiring the caller to give the server bind address.
       startRevSSHTunnel() is the original interface and is kept. This variant
       makes the bind address on the server a mandatory first argument and
       delegates to startRevSSHTunnel().
       @param serverBindAddress The server address to bind to.
       @param serverPort The TCP server port. On a reverse port forwarding connection
              the TCP server runs on the dest end of the ssh connection
              (the machine at the remote end of the ssh connection).
       @param destHost The host address of the tunnel destination at the local
              end of the ssh connection.
       @param destPort The host TCP port of the tunnel destination at the local
              end of the ssh connection."""
    startTunnel = self.startRevSSHTunnel
    startTunnel(serverPort, destHost, destPort, serverBindAddress=serverBindAddress)
781
+
782
def stopRevSSHTunnel(self, serverPort):
    """@brief Close the channel and socket recorded for a reverse ssh tunnel.
       Does nothing if serverPort has no entry in the reverse tunnel table.
       NOTE(review): the table entry is kept and no cancel request is sent for
       the remote port forward, so the remote server may still accept
       connections on serverPort - confirm this is intended.
       @param serverPort The remote TCP server port the reverse tunnel was
              started on."""
    if serverPort not in self._reverseSShDict:
        return
    openChan, openSock = self._reverseSShDict[serverPort][2:4]
    # Either reference may still be the None placeholder stored by
    # startRevSSHTunnel() if no connection has been handled yet.
    for endpoint in (openChan, openSock):
        if endpoint:
            endpoint.close()
    self._info("Shutdown reverse ssh port forwarding connection using remote server port %d." % (serverPort))
798
+
799
def stopAllRevSSHTunnels(self):
    """@brief Close the channel and socket recorded for every reverse ssh tunnel,
       reporting each remote server port at info level."""
    # Walk a snapshot of the ports; handler threads write to this table.
    for port in list(self._reverseSShDict):
        tunnelParams = self._reverseSShDict[port]
        for endpoint in (tunnelParams[2], tunnelParams[3]):
            # None placeholders mean nothing is open on that side yet.
            if endpoint:
                endpoint.close()
        self._info("Shutdown reverse ssh port forwarding connection using remote server port %d." % (port))
813
+
814
def stopAllSSHTunnels(self):
    """@brief Stop every ssh tunnel: forward tunnels first, then reverse tunnels."""
    for stopTunnels in (self.stopAllFwdSSHTunnels, self.stopAllRevSSHTunnels):
        stopTunnels()
818
+
819
+ # !!! The following methods are internal and should not be called externally.
820
def _getDestination(self, serverPort):
    """@brief Look up the tunnel destination registered for a remote server port.
       @param serverPort The TCP server port on the ssh server.
       @return A (destHost, destPort) tuple, or None if the port is unknown."""
    try:
        tunnelParams = self._reverseSShDict[serverPort]
    except KeyError:
        return None
    return (tunnelParams[0], tunnelParams[1])
828
+
829
def _startReverseForwardingHandler(self, chan, xxx_todo_changeme, xxx_todo_changeme1):
    """@brief Called when a channel is connected in order to start a handler thread for it.
       This is the handler passed to request_port_forward(). It looks up the
       destination registered for the server port and runs
       _reverseForwardingHandler() on a daemon thread. If no destination is
       registered, an error is reported and the channel is closed.
       @param chan The newly connected channel.
       @param xxx_todo_changeme The (origin address, origin port) of the connection (unused).
       @param xxx_todo_changeme1 The (server address, server port) the connection arrived on."""
    # Parameter names are kept unchanged for backward compatibility.
    (_serverAddr, serverPort) = xxx_todo_changeme1
    destination = self._getDestination(serverPort)
    if destination is None:
        # Unpacking None would raise TypeError here and leave chan open.
        self._error('Reverse forwarding request on server port %d has no registered destination.' % (serverPort))
        chan.close()
        return
    destHost, destPort = destination

    # daemon=True replaces Thread.setDaemon(), which is deprecated.
    hThread = threading.Thread(target=self._reverseForwardingHandler,
                               args=(chan, serverPort, destHost, destPort),
                               daemon=True)
    hThread.start()
838
+
839
def _reverseForwardingHandler(self, chan, serverPort, destHost, destPort):
    """@brief Handle a reverse ssh forwarding connection.
       Connects a new TCP socket to destHost:destPort and relays data in both
       directions between it and chan until either side closes or a send fails.
       Both chan and the socket are closed before returning, including when the
       connect fails.
       @param chan A connected channel over an ssh connection.
       @param serverPort The server port (on remote ssh server) from where the reverse ssh connection originated.
       @param destHost The destination host address.
       @param destPort The destination port address."""
    sock = socket.socket()
    # Record the channel and socket so stopRevSSHTunnel()/stopAllRevSSHTunnels()
    # can close them.
    self._reverseSShDict[serverPort] = (destHost, destPort, chan, sock)
    try:
        sock.connect((destHost, destPort))
    except Exception as e:
        self._error('Forwarding (reverse) request to %s:%d failed: %r' % (destHost, destPort, e))
        # Don't leave the remote end's channel (or the unused socket) open.
        chan.close()
        sock.close()
        return

    self._info('Connected! Reverse tunnel open %r -> %r -> %r' % (chan.origin_addr,
               chan.getpeername(), (destHost, destPort)))
    try:
        while True:
            readable, _, _ = select.select([sock, chan], [], [])
            if sock in readable:
                data = sock.recv(SSHTunnelManager.RX_BUFFER_SIZE)
                if not data:
                    break
                try:
                    # sendall(): send() may transmit only part of the data.
                    chan.sendall(data)
                except Exception:
                    # The other side has gone; end the tunnel.
                    break
            if chan in readable:
                data = chan.recv(SSHTunnelManager.RX_BUFFER_SIZE)
                if not data:
                    break
                try:
                    sock.sendall(data)
                except Exception:
                    break
    finally:
        # Release both ends even if recv/select raised.
        chan.close()
        sock.close()
    self._info('Tunnel closed from server port %d' % (serverPort))
878
+
879
+
880
class ForwardingServer(socketserver.ThreadingTCPServer):
    """@brief Threaded TCP server used for ssh port forwarding."""
    # Allow the listening address to be rebound straight after a previous close.
    allow_reuse_address = True
    # Connection handler threads are daemon threads.
    daemon_threads = True
884
+
885
+
886
class ForwardingHandler(socketserver.BaseRequestHandler):
    """@brief Handler for ssh port forwarding connections.
       Each accepted local connection is relayed over a new 'direct-tcpip'
       channel, opened on ssh_transport, to chain_host:chain_port.
       This class reads the attributes ssh_transport, chain_host, chain_port
       and uo but does not define them; they must be supplied before handle()
       runs (for example as class attributes of a subclass)."""

    def _info(self, text):
        """@brief Present an info level message to the user.
           Does nothing if no user output object (uo) is set.
           @param text The text to be presented to the user."""
        uo = getattr(self, "uo", None)
        if uo:
            uo.info(text)

    def _error(self, text):
        """@brief Present an error level message to the user.
           Does nothing if no user output object (uo) is set.
           @param text The text to be presented to the user."""
        uo = getattr(self, "uo", None)
        if uo:
            uo.error(text)

    def handle(self):
        """@brief Relay data in both directions between the accepted local
           connection (self.request) and a newly opened ssh channel until either
           side closes. Both the channel and the connection are closed before
           returning."""
        try:
            chan = self.ssh_transport.open_channel('direct-tcpip',
                                                   (self.chain_host, self.chain_port),
                                                   self.request.getpeername())
        except Exception as e:
            self._error('Incoming request to %s:%d failed: %s' % (self.chain_host,
                                                                  self.chain_port,
                                                                  repr(e)))
            return
        if chan is None:
            self._error('Incoming request to %s:%d was rejected by the SSH server.' %
                        (self.chain_host, self.chain_port))
            return

        # Read the peer address up front; once the connection is broken
        # getpeername() may raise.
        peername = self.request.getpeername()
        self._info('Connected! Tunnel open %r -> %r -> %r' % (peername,
                                                             chan.getpeername(),
                                                             (self.chain_host, self.chain_port)))
        try:
            while True:
                readable, _, _ = select.select([self.request, chan], [], [])
                if self.request in readable:
                    data = self.request.recv(SSHTunnelManager.RX_BUFFER_SIZE)
                    if not data:
                        break
                    # sendall(): send() may transmit only part of the data.
                    chan.sendall(data)
                if chan in readable:
                    data = chan.recv(SSHTunnelManager.RX_BUFFER_SIZE)
                    if not data:
                        break
                    self.request.sendall(data)
        except OSError as e:
            self._error('Tunnel from %r failed: %r' % (peername, e))
        finally:
            # Always release both ends, even if the relay raised.
            chan.close()
            self.request.close()
        self._info('Tunnel closed from %r' % (peername,))