testgres 1.9.1__tar.gz → 1.9.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. {testgres-1.9.1/testgres.egg-info → testgres-1.9.3}/PKG-INFO +3 -3
  2. {testgres-1.9.1 → testgres-1.9.3}/setup.py +5 -7
  3. {testgres-1.9.1 → testgres-1.9.3}/testgres/__init__.py +3 -1
  4. {testgres-1.9.1 → testgres-1.9.3}/testgres/exceptions.py +10 -1
  5. testgres-1.9.3/testgres/helpers/port_manager.py +40 -0
  6. {testgres-1.9.1 → testgres-1.9.3}/testgres/node.py +8 -6
  7. testgres-1.9.3/testgres/operations/__init__.py +0 -0
  8. {testgres-1.9.1 → testgres-1.9.3}/testgres/operations/local_ops.py +81 -61
  9. {testgres-1.9.1 → testgres-1.9.3}/testgres/operations/os_ops.py +7 -1
  10. {testgres-1.9.1 → testgres-1.9.3}/testgres/operations/remote_ops.py +78 -84
  11. {testgres-1.9.1 → testgres-1.9.3}/testgres/utils.py +9 -4
  12. {testgres-1.9.1 → testgres-1.9.3/testgres.egg-info}/PKG-INFO +3 -3
  13. {testgres-1.9.1 → testgres-1.9.3}/testgres.egg-info/SOURCES.txt +2 -0
  14. {testgres-1.9.1 → testgres-1.9.3}/testgres.egg-info/requires.txt +0 -2
  15. {testgres-1.9.1 → testgres-1.9.3}/tests/test_remote.py +4 -5
  16. {testgres-1.9.1 → testgres-1.9.3}/tests/test_simple.py +35 -12
  17. {testgres-1.9.1 → testgres-1.9.3}/tests/test_simple_remote.py +7 -7
  18. {testgres-1.9.1 → testgres-1.9.3}/LICENSE +0 -0
  19. {testgres-1.9.1 → testgres-1.9.3}/MANIFEST.in +0 -0
  20. {testgres-1.9.1 → testgres-1.9.3}/README.md +0 -0
  21. {testgres-1.9.1 → testgres-1.9.3}/setup.cfg +0 -0
  22. {testgres-1.9.1 → testgres-1.9.3}/testgres/api.py +0 -0
  23. {testgres-1.9.1 → testgres-1.9.3}/testgres/backup.py +0 -0
  24. {testgres-1.9.1 → testgres-1.9.3}/testgres/cache.py +0 -0
  25. {testgres-1.9.1 → testgres-1.9.3}/testgres/config.py +0 -0
  26. {testgres-1.9.1 → testgres-1.9.3}/testgres/connection.py +0 -0
  27. {testgres-1.9.1 → testgres-1.9.3}/testgres/consts.py +0 -0
  28. {testgres-1.9.1 → testgres-1.9.3}/testgres/decorators.py +0 -0
  29. {testgres-1.9.1 → testgres-1.9.3}/testgres/defaults.py +0 -0
  30. {testgres-1.9.1 → testgres-1.9.3}/testgres/enums.py +0 -0
  31. {testgres-1.9.1/testgres/operations → testgres-1.9.3/testgres/helpers}/__init__.py +0 -0
  32. {testgres-1.9.1 → testgres-1.9.3}/testgres/logger.py +0 -0
  33. {testgres-1.9.1 → testgres-1.9.3}/testgres/pubsub.py +0 -0
  34. {testgres-1.9.1 → testgres-1.9.3}/testgres/standby.py +0 -0
  35. {testgres-1.9.1 → testgres-1.9.3}/testgres.egg-info/dependency_links.txt +0 -0
  36. {testgres-1.9.1 → testgres-1.9.3}/testgres.egg-info/top_level.txt +0 -0
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: testgres
3
- Version: 1.9.1
3
+ Version: 1.9.3
4
4
  Summary: Testing utility for PostgreSQL and its extensions
5
5
  Home-page: https://github.com/postgrespro/testgres
6
- Author: Ildar Musin
7
- Author-email: zildermann@gmail.com
6
+ Author: Postgres Professional
7
+ Author-email: testgres@postgrespro.ru
8
8
  License: PostgreSQL
9
9
  Keywords: test,testing,postgresql
10
10
  Platform: UNKNOWN
@@ -11,9 +11,7 @@ install_requires = [
11
11
  "port-for>=0.4",
12
12
  "six>=1.9.0",
13
13
  "psutil",
14
- "packaging",
15
- "fabric",
16
- "sshtunnel"
14
+ "packaging"
17
15
  ]
18
16
 
19
17
  # Add compatibility enum class
@@ -29,16 +27,16 @@ with open('README.md', 'r') as f:
29
27
  readme = f.read()
30
28
 
31
29
  setup(
32
- version='1.9.1',
30
+ version='1.9.3',
33
31
  name='testgres',
34
- packages=['testgres', 'testgres.operations'],
32
+ packages=['testgres', 'testgres.operations', 'testgres.helpers'],
35
33
  description='Testing utility for PostgreSQL and its extensions',
36
34
  url='https://github.com/postgrespro/testgres',
37
35
  long_description=readme,
38
36
  long_description_content_type='text/markdown',
39
37
  license='PostgreSQL',
40
- author='Ildar Musin',
41
- author_email='zildermann@gmail.com',
38
+ author='Postgres Professional',
39
+ author_email='testgres@postgrespro.ru',
42
40
  keywords=['test', 'testing', 'postgresql'],
43
41
  install_requires=install_requires,
44
42
  classifiers=[],
@@ -52,6 +52,8 @@ from .operations.os_ops import OsOperations, ConnectionParams
52
52
  from .operations.local_ops import LocalOperations
53
53
  from .operations.remote_ops import RemoteOperations
54
54
 
55
+ from .helpers.port_manager import PortManager
56
+
55
57
  __all__ = [
56
58
  "get_new_node",
57
59
  "get_remote_node",
@@ -62,6 +64,6 @@ __all__ = [
62
64
  "XLogMethod", "IsolationLevel", "NodeStatus", "ProcessType", "DumpFormat",
63
65
  "PostgresNode", "NodeApp",
64
66
  "reserve_port", "release_port", "bound_ports", "get_bin_path", "get_pg_config", "get_pg_version",
65
- "First", "Any",
67
+ "First", "Any", "PortManager",
66
68
  "OsOperations", "LocalOperations", "RemoteOperations", "ConnectionParams"
67
69
  ]
@@ -32,7 +32,16 @@ class ExecUtilException(TestgresException):
32
32
  if self.out:
33
33
  msg.append(u'----\n{}'.format(self.out))
34
34
 
35
- return six.text_type('\n').join(msg)
35
+ return self.convert_and_join(msg)
36
+
37
+ @staticmethod
38
+ def convert_and_join(msg_list):
39
+ # Convert each byte element in the list to str
40
+ str_list = [six.text_type(item, 'utf-8') if isinstance(item, bytes) else six.text_type(item) for item in
41
+ msg_list]
42
+
43
+ # Join the list into a single string with the specified delimiter
44
+ return six.text_type('\n').join(str_list)
36
45
 
37
46
 
38
47
  @six.python_2_unicode_compatible
@@ -0,0 +1,40 @@
1
+ import socket
2
+ import random
3
+ from typing import Set, Iterable, Optional
4
+
5
+
6
+ class PortForException(Exception):
7
+ pass
8
+
9
+
10
+ class PortManager:
11
+ def __init__(self, ports_range=(1024, 65535)):
12
+ self.ports_range = ports_range
13
+
14
+ @staticmethod
15
+ def is_port_free(port: int) -> bool:
16
+ """Check if a port is free to use."""
17
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
18
+ try:
19
+ s.bind(("", port))
20
+ return True
21
+ except OSError:
22
+ return False
23
+
24
+ def find_free_port(self, ports: Optional[Set[int]] = None, exclude_ports: Optional[Iterable[int]] = None) -> int:
25
+ """Return a random unused port number."""
26
+ if ports is None:
27
+ ports = set(range(1024, 65535))
28
+
29
+ if exclude_ports is None:
30
+ exclude_ports = set()
31
+
32
+ ports.difference_update(set(exclude_ports))
33
+
34
+ sampled_ports = random.sample(tuple(ports), min(len(ports), 100))
35
+
36
+ for port in sampled_ports:
37
+ if self.is_port_free(port):
38
+ return port
39
+
40
+ raise PortForException("Can't select a port")
@@ -623,8 +623,8 @@ class PostgresNode(object):
623
623
  "-D", self.data_dir,
624
624
  "status"
625
625
  ] # yapf: disable
626
- status_code, out, err = execute_utility(_params, self.utils_log_file, verbose=True)
627
- if 'does not exist' in err:
626
+ status_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True)
627
+ if error and 'does not exist' in error:
628
628
  return NodeStatus.Uninitialized
629
629
  elif 'no server running' in out:
630
630
  return NodeStatus.Stopped
@@ -659,7 +659,7 @@ class PostgresNode(object):
659
659
 
660
660
  return out_dict
661
661
 
662
- def slow_start(self, replica=False, dbname='template1', username=default_username(), max_attempts=0):
662
+ def slow_start(self, replica=False, dbname='template1', username=None, max_attempts=0):
663
663
  """
664
664
  Starts the PostgreSQL instance and then polls the instance
665
665
  until it reaches the expected state (primary or replica). The state is checked
@@ -672,6 +672,8 @@ class PostgresNode(object):
672
672
  If False, waits for the instance to be in primary mode. Default is False.
673
673
  max_attempts:
674
674
  """
675
+ if not username:
676
+ username = default_username()
675
677
  self.start()
676
678
 
677
679
  if replica:
@@ -715,7 +717,7 @@ class PostgresNode(object):
715
717
 
716
718
  try:
717
719
  exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True)
718
- if 'does not exist' in error:
720
+ if error and 'does not exist' in error:
719
721
  raise Exception
720
722
  except Exception as e:
721
723
  msg = 'Cannot start node'
@@ -789,7 +791,7 @@ class PostgresNode(object):
789
791
 
790
792
  try:
791
793
  error_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True)
792
- if 'could not start server' in error:
794
+ if error and 'could not start server' in error:
793
795
  raise ExecUtilException
794
796
  except ExecUtilException as e:
795
797
  msg = 'Cannot restart node'
@@ -1371,7 +1373,7 @@ class PostgresNode(object):
1371
1373
  # should be the last one
1372
1374
  _params.append(dbname)
1373
1375
 
1374
- proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, proc=True)
1376
+ proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, get_process=True)
1375
1377
 
1376
1378
  return proc
1377
1379
 
File without changes
@@ -8,8 +8,7 @@ import tempfile
8
8
  import psutil
9
9
 
10
10
  from ..exceptions import ExecUtilException
11
- from .os_ops import ConnectionParams, OsOperations
12
- from .os_ops import pglib
11
+ from .os_ops import ConnectionParams, OsOperations, pglib, get_default_encoding
13
12
 
14
13
  try:
15
14
  from shutil import which as find_executable
@@ -18,11 +17,18 @@ except ImportError:
18
17
  from distutils.spawn import find_executable
19
18
  from distutils import rmtree
20
19
 
21
-
22
20
  CMD_TIMEOUT_SEC = 60
23
21
  error_markers = [b'error', b'Permission denied', b'fatal']
24
22
 
25
23
 
24
+ def has_errors(output):
25
+ if output:
26
+ if isinstance(output, str):
27
+ output = output.encode(get_default_encoding())
28
+ return any(marker in output for marker in error_markers)
29
+ return False
30
+
31
+
26
32
  class LocalOperations(OsOperations):
27
33
  def __init__(self, conn_params=None):
28
34
  if conn_params is None:
@@ -34,66 +40,80 @@ class LocalOperations(OsOperations):
34
40
  self.remote = False
35
41
  self.username = conn_params.username or self.get_user()
36
42
 
37
- # Command execution
38
- def exec_command(self, cmd, wait_exit=False, verbose=False,
39
- expect_error=False, encoding=None, shell=False, text=False,
40
- input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None):
41
- """
42
- Execute a command in a subprocess.
43
-
44
- Args:
45
- - cmd: The command to execute.
46
- - wait_exit: Whether to wait for the subprocess to exit before returning.
47
- - verbose: Whether to return verbose output.
48
- - expect_error: Whether to raise an error if the subprocess exits with an error status.
49
- - encoding: The encoding to use for decoding the subprocess output.
50
- - shell: Whether to use shell when executing the subprocess.
51
- - text: Whether to return str instead of bytes for the subprocess output.
52
- - input: The input to pass to the subprocess.
53
- - stdout: The stdout to use for the subprocess.
54
- - stderr: The stderr to use for the subprocess.
55
- - proc: The process to use for subprocess creation.
56
- :return: The output of the subprocess.
57
- """
58
- if os.name == 'nt':
59
- with tempfile.NamedTemporaryFile() as buf:
60
- process = subprocess.Popen(cmd, stdout=buf, stderr=subprocess.STDOUT)
61
- process.communicate()
62
- buf.seek(0)
63
- result = buf.read().decode(encoding)
64
- return result
65
- else:
43
+ @staticmethod
44
+ def _raise_exec_exception(message, command, exit_code, output):
45
+ """Raise an ExecUtilException."""
46
+ raise ExecUtilException(message=message.format(output),
47
+ command=command,
48
+ exit_code=exit_code,
49
+ out=output)
50
+
51
+ @staticmethod
52
+ def _process_output(encoding, temp_file_path):
53
+ """Process the output of a command from a temporary file."""
54
+ with open(temp_file_path, 'rb') as temp_file:
55
+ output = temp_file.read()
56
+ if encoding:
57
+ output = output.decode(encoding)
58
+ return output, None # In Windows stderr writing in stdout
59
+
60
+ def _run_command(self, cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding):
61
+ """Execute a command and return the process and its output."""
62
+ if os.name == 'nt' and stdout is None: # Windows
63
+ with tempfile.NamedTemporaryFile(mode='w+b', delete=False) as temp_file:
64
+ stdout = temp_file
65
+ stderr = subprocess.STDOUT
66
+ process = subprocess.Popen(
67
+ cmd,
68
+ shell=shell,
69
+ stdin=stdin or subprocess.PIPE if input is not None else None,
70
+ stdout=stdout,
71
+ stderr=stderr,
72
+ )
73
+ if get_process:
74
+ return process, None, None
75
+ temp_file_path = temp_file.name
76
+
77
+ # Wait process finished
78
+ process.wait()
79
+
80
+ output, error = self._process_output(encoding, temp_file_path)
81
+ return process, output, error
82
+ else: # Other OS
66
83
  process = subprocess.Popen(
67
84
  cmd,
68
85
  shell=shell,
69
- stdout=stdout,
70
- stderr=stderr,
86
+ stdin=stdin or subprocess.PIPE if input is not None else None,
87
+ stdout=stdout or subprocess.PIPE,
88
+ stderr=stderr or subprocess.PIPE,
71
89
  )
72
- if proc:
73
- return process
74
- result, error = process.communicate(input)
75
- exit_status = process.returncode
76
-
77
- error_found = exit_status != 0 or any(marker in error for marker in error_markers)
78
-
79
- if encoding:
80
- result = result.decode(encoding)
81
- error = error.decode(encoding)
82
-
83
- if expect_error:
84
- raise Exception(result, error)
85
-
86
- if exit_status != 0 or error_found:
87
- if exit_status == 0:
88
- exit_status = 1
89
- raise ExecUtilException(message='Utility exited with non-zero code. Error `{}`'.format(error),
90
- command=cmd,
91
- exit_code=exit_status,
92
- out=result)
93
- if verbose:
94
- return exit_status, result, error
95
- else:
96
- return result
90
+ if get_process:
91
+ return process, None, None
92
+ try:
93
+ output, error = process.communicate(input=input.encode(encoding) if input else None, timeout=timeout)
94
+ if encoding:
95
+ output = output.decode(encoding)
96
+ error = error.decode(encoding)
97
+ return process, output, error
98
+ except subprocess.TimeoutExpired:
99
+ process.kill()
100
+ raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
101
+
102
+ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=False,
103
+ text=False, input=None, stdin=None, stdout=None, stderr=None, get_process=False, timeout=None):
104
+ """
105
+ Execute a command in a subprocess and handle the output based on the provided parameters.
106
+ """
107
+ process, output, error = self._run_command(cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding)
108
+ if get_process:
109
+ return process
110
+ if process.returncode != 0 or (has_errors(error) and not expect_error):
111
+ self._raise_exec_exception('Utility exited with non-zero code. Error `{}`', cmd, process.returncode, error)
112
+
113
+ if verbose:
114
+ return process.returncode, output, error
115
+ else:
116
+ return output
97
117
 
98
118
  # Environment setup
99
119
  def environ(self, var_name):
@@ -112,7 +132,7 @@ class LocalOperations(OsOperations):
112
132
 
113
133
  # Get environment variables
114
134
  def get_user(self):
115
- return getpass.getuser()
135
+ return self.username or getpass.getuser()
116
136
 
117
137
  def get_name(self):
118
138
  return os.name
@@ -205,7 +225,7 @@ class LocalOperations(OsOperations):
205
225
  if binary:
206
226
  return content
207
227
  if isinstance(content, bytes):
208
- return content.decode(encoding or 'utf-8')
228
+ return content.decode(encoding or get_default_encoding())
209
229
  return content
210
230
 
211
231
  def readlines(self, filename, num_lines=0, binary=False, encoding=None):
@@ -1,3 +1,5 @@
1
+ import locale
2
+
1
3
  try:
2
4
  import psycopg2 as pglib # noqa: F401
3
5
  except ImportError:
@@ -14,6 +16,10 @@ class ConnectionParams:
14
16
  self.username = username
15
17
 
16
18
 
19
+ def get_default_encoding():
20
+ return locale.getdefaultlocale()[1] or 'UTF-8'
21
+
22
+
17
23
  class OsOperations:
18
24
  def __init__(self, username=None):
19
25
  self.ssh_key = None
@@ -75,7 +81,7 @@ class OsOperations:
75
81
  def touch(self, filename):
76
82
  raise NotImplementedError()
77
83
 
78
- def read(self, filename):
84
+ def read(self, filename, encoding, binary):
79
85
  raise NotImplementedError()
80
86
 
81
87
  def readlines(self, filename):
@@ -1,23 +1,20 @@
1
- import locale
2
1
  import logging
3
2
  import os
4
3
  import subprocess
5
4
  import tempfile
6
- import time
5
+ import platform
7
6
 
8
- import sshtunnel
7
+ # we support both pg8000 and psycopg2
8
+ try:
9
+ import psycopg2 as pglib
10
+ except ImportError:
11
+ try:
12
+ import pg8000 as pglib
13
+ except ImportError:
14
+ raise ImportError("You must have psycopg2 or pg8000 modules installed")
9
15
 
10
16
  from ..exceptions import ExecUtilException
11
-
12
- from .os_ops import OsOperations, ConnectionParams
13
- from .os_ops import pglib
14
-
15
- sshtunnel.SSH_TIMEOUT = 5.0
16
- sshtunnel.TUNNEL_TIMEOUT = 5.0
17
-
18
- ConsoleEncoding = locale.getdefaultlocale()[1]
19
- if not ConsoleEncoding:
20
- ConsoleEncoding = 'UTF-8'
17
+ from .os_ops import OsOperations, ConnectionParams, get_default_encoding
21
18
 
22
19
  error_markers = [b'error', b'Permission denied', b'fatal', b'No such file or directory']
23
20
 
@@ -33,66 +30,84 @@ class PsUtilProcessProxy:
33
30
 
34
31
  def cmdline(self):
35
32
  command = "ps -p {} -o cmd --no-headers".format(self.pid)
36
- stdin, stdout, stderr = self.ssh.exec_command(command, verbose=True, encoding=ConsoleEncoding)
33
+ stdin, stdout, stderr = self.ssh.exec_command(command, verbose=True, encoding=get_default_encoding())
37
34
  cmdline = stdout.strip()
38
35
  return cmdline.split()
39
36
 
40
37
 
41
38
  class RemoteOperations(OsOperations):
42
39
  def __init__(self, conn_params: ConnectionParams):
43
- if os.name != "posix":
40
+
41
+ if not platform.system().lower() == "linux":
44
42
  raise EnvironmentError("Remote operations are supported only on Linux!")
45
43
 
46
44
  super().__init__(conn_params.username)
47
45
  self.conn_params = conn_params
48
46
  self.host = conn_params.host
49
47
  self.ssh_key = conn_params.ssh_key
48
+ if self.ssh_key:
49
+ self.ssh_cmd = ["-i", self.ssh_key]
50
+ else:
51
+ self.ssh_cmd = []
50
52
  self.remote = True
51
53
  self.username = conn_params.username or self.get_user()
52
54
  self.add_known_host(self.host)
55
+ self.tunnel_process = None
53
56
 
54
57
  def __enter__(self):
55
58
  return self
56
59
 
57
60
  def __exit__(self, exc_type, exc_val, exc_tb):
58
- self.close_tunnel()
61
+ self.close_ssh_tunnel()
59
62
 
60
- def close_tunnel(self):
61
- if getattr(self, 'tunnel', None):
62
- self.tunnel.stop(force=True)
63
- start_time = time.time()
64
- while self.tunnel.is_active:
65
- if time.time() - start_time > sshtunnel.TUNNEL_TIMEOUT:
66
- break
67
- time.sleep(0.5)
63
+ def establish_ssh_tunnel(self, local_port, remote_port):
64
+ """
65
+ Establish an SSH tunnel from a local port to a remote PostgreSQL port.
66
+ """
67
+ ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"]
68
+ self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300)
69
+
70
+ def close_ssh_tunnel(self):
71
+ if hasattr(self, 'tunnel_process'):
72
+ self.tunnel_process.terminate()
73
+ self.tunnel_process.wait()
74
+ del self.tunnel_process
75
+ else:
76
+ print("No active tunnel to close.")
68
77
 
69
78
  def add_known_host(self, host):
70
- cmd = 'ssh-keyscan -H %s >> /home/%s/.ssh/known_hosts' % (host, os.getlogin())
79
+ known_hosts_path = os.path.expanduser("~/.ssh/known_hosts")
80
+ cmd = 'ssh-keyscan -H %s >> %s' % (host, known_hosts_path)
81
+
71
82
  try:
72
- subprocess.check_call(
73
- cmd,
74
- shell=True,
75
- )
83
+ subprocess.check_call(cmd, shell=True)
76
84
  logging.info("Successfully added %s to known_hosts." % host)
77
85
  except subprocess.CalledProcessError as e:
78
- raise ExecUtilException(message="Failed to add %s to known_hosts. Error: %s" % (host, str(e)), command=cmd,
79
- exit_code=e.returncode, out=e.stderr)
86
+ raise Exception("Failed to add %s to known_hosts. Error: %s" % (host, str(e)))
80
87
 
81
- def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False,
88
+ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
82
89
  encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None,
83
- stderr=None, proc=None):
90
+ stderr=None, get_process=None, timeout=None):
84
91
  """
85
92
  Execute a command in the SSH session.
86
93
  Args:
87
94
  - cmd (str): The command to be executed.
88
95
  """
96
+ ssh_cmd = []
89
97
  if isinstance(cmd, str):
90
- ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key, cmd]
98
+ ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + [cmd]
91
99
  elif isinstance(cmd, list):
92
- ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key] + cmd
100
+ ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + cmd
93
101
  process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
102
+ if get_process:
103
+ return process
104
+
105
+ try:
106
+ result, error = process.communicate(input, timeout=timeout)
107
+ except subprocess.TimeoutExpired:
108
+ process.kill()
109
+ raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
94
110
 
95
- result, error = process.communicate(input)
96
111
  exit_status = process.returncode
97
112
 
98
113
  if encoding:
@@ -128,7 +143,7 @@ class RemoteOperations(OsOperations):
128
143
  - var_name (str): The name of the environment variable.
129
144
  """
130
145
  cmd = "echo ${}".format(var_name)
131
- return self.exec_command(cmd, encoding=ConsoleEncoding).strip()
146
+ return self.exec_command(cmd, encoding=get_default_encoding()).strip()
132
147
 
133
148
  def find_executable(self, executable):
134
149
  search_paths = self.environ("PATH")
@@ -159,11 +174,11 @@ class RemoteOperations(OsOperations):
159
174
 
160
175
  # Get environment variables
161
176
  def get_user(self):
162
- return self.exec_command("echo $USER", encoding=ConsoleEncoding).strip()
177
+ return self.exec_command("echo $USER", encoding=get_default_encoding()).strip()
163
178
 
164
179
  def get_name(self):
165
180
  cmd = 'python3 -c "import os; print(os.name)"'
166
- return self.exec_command(cmd, encoding=ConsoleEncoding).strip()
181
+ return self.exec_command(cmd, encoding=get_default_encoding()).strip()
167
182
 
168
183
  # Work with dirs
169
184
  def makedirs(self, path, remove_existing=False):
@@ -210,7 +225,7 @@ class RemoteOperations(OsOperations):
210
225
  return result.splitlines()
211
226
 
212
227
  def path_exists(self, path):
213
- result = self.exec_command("test -e {}; echo $?".format(path), encoding=ConsoleEncoding)
228
+ result = self.exec_command("test -e {}; echo $?".format(path), encoding=get_default_encoding())
214
229
  return int(result.strip()) == 0
215
230
 
216
231
  @property
@@ -231,9 +246,9 @@ class RemoteOperations(OsOperations):
231
246
  - prefix (str): The prefix of the temporary directory name.
232
247
  """
233
248
  if prefix:
234
- command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
249
+ command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
235
250
  else:
236
- command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", "mktemp -d"]
251
+ command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", "mktemp -d"]
237
252
 
238
253
  result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
239
254
 
@@ -247,9 +262,9 @@ class RemoteOperations(OsOperations):
247
262
 
248
263
  def mkstemp(self, prefix=None):
249
264
  if prefix:
250
- temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding=ConsoleEncoding)
265
+ temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding=get_default_encoding())
251
266
  else:
252
- temp_dir = self.exec_command("mktemp", encoding=ConsoleEncoding)
267
+ temp_dir = self.exec_command("mktemp", encoding=get_default_encoding())
253
268
 
254
269
  if temp_dir:
255
270
  if not os.path.isabs(temp_dir):
@@ -266,7 +281,9 @@ class RemoteOperations(OsOperations):
266
281
  return self.exec_command("cp -r {} {}".format(src, dst))
267
282
 
268
283
  # Work with files
269
- def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding=ConsoleEncoding):
284
+ def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding=None):
285
+ if not encoding:
286
+ encoding = get_default_encoding()
270
287
  mode = "wb" if binary else "w"
271
288
  if not truncate:
272
289
  mode = "ab" if binary else "a"
@@ -275,7 +292,7 @@ class RemoteOperations(OsOperations):
275
292
 
276
293
  with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file:
277
294
  if not truncate:
278
- scp_cmd = ['scp', '-i', self.ssh_key, f"{self.username}@{self.host}:{filename}", tmp_file.name]
295
+ scp_cmd = ['scp'] + self.ssh_cmd + [f"{self.username}@{self.host}:{filename}", tmp_file.name]
279
296
  subprocess.run(scp_cmd, check=False) # The file might not exist yet
280
297
  tmp_file.seek(0, os.SEEK_END)
281
298
 
@@ -285,18 +302,17 @@ class RemoteOperations(OsOperations):
285
302
  data = data.encode(encoding)
286
303
 
287
304
  if isinstance(data, list):
288
- data = [(s if isinstance(s, str) else s.decode(ConsoleEncoding)).rstrip('\n') + '\n' for s in data]
305
+ data = [(s if isinstance(s, str) else s.decode(get_default_encoding())).rstrip('\n') + '\n' for s in data]
289
306
  tmp_file.writelines(data)
290
307
  else:
291
308
  tmp_file.write(data)
292
309
 
293
310
  tmp_file.flush()
294
-
295
- scp_cmd = ['scp', '-i', self.ssh_key, tmp_file.name, f"{self.username}@{self.host}:{filename}"]
311
+ scp_cmd = ['scp'] + self.ssh_cmd + [tmp_file.name, f"{self.username}@{self.host}:{filename}"]
296
312
  subprocess.run(scp_cmd, check=True)
297
313
 
298
314
  remote_directory = os.path.dirname(filename)
299
- mkdir_cmd = ['ssh', '-i', self.ssh_key, f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"]
315
+ mkdir_cmd = ['ssh'] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"]
300
316
  subprocess.run(mkdir_cmd, check=True)
301
317
 
302
318
  os.remove(tmp_file.name)
@@ -317,7 +333,7 @@ class RemoteOperations(OsOperations):
317
333
  result = self.exec_command(cmd, encoding=encoding)
318
334
 
319
335
  if not binary and result:
320
- result = result.decode(encoding or ConsoleEncoding)
336
+ result = result.decode(encoding or get_default_encoding())
321
337
 
322
338
  return result
323
339
 
@@ -330,7 +346,7 @@ class RemoteOperations(OsOperations):
330
346
  result = self.exec_command(cmd, encoding=encoding)
331
347
 
332
348
  if not binary and result:
333
- lines = result.decode(encoding or ConsoleEncoding).splitlines()
349
+ lines = result.decode(encoding or get_default_encoding()).splitlines()
334
350
  else:
335
351
  lines = result.splitlines()
336
352
 
@@ -358,10 +374,10 @@ class RemoteOperations(OsOperations):
358
374
 
359
375
  def get_pid(self):
360
376
  # Get current process id
361
- return int(self.exec_command("echo $$", encoding=ConsoleEncoding))
377
+ return int(self.exec_command("echo $$", encoding=get_default_encoding()))
362
378
 
363
379
  def get_process_children(self, pid):
364
- command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"pgrep -P {pid}"]
380
+ command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"pgrep -P {pid}"]
365
381
 
366
382
  result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
367
383
 
@@ -372,41 +388,19 @@ class RemoteOperations(OsOperations):
372
388
  raise ExecUtilException(f"Error in getting process children. Error: {result.stderr}")
373
389
 
374
390
  # Database control
375
- def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, ssh_key=None):
391
+ def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
376
392
  """
377
- Connects to a PostgreSQL database on the remote system.
378
- Args:
379
- - dbname (str): The name of the database to connect to.
380
- - user (str): The username for the database connection.
381
- - password (str, optional): The password for the database connection. Defaults to None.
382
- - host (str, optional): The IP address of the remote system. Defaults to "localhost".
383
- - port (int, optional): The port number of the PostgreSQL service. Defaults to 5432.
384
-
385
- This function establishes a connection to a PostgreSQL database on the remote system using the specified
386
- parameters. It returns a connection object that can be used to interact with the database.
393
+ Established SSH tunnel and Connects to a PostgreSQL
387
394
  """
388
- self.close_tunnel()
389
- self.tunnel = sshtunnel.open_tunnel(
390
- (self.host, 22), # Remote server IP and SSH port
391
- ssh_username=self.username,
392
- ssh_pkey=self.ssh_key,
393
- remote_bind_address=(self.host, port), # PostgreSQL server IP and PostgreSQL port
394
- local_bind_address=('localhost', 0)
395
- # Local machine IP and available port (0 means it will pick any available port)
396
- )
397
- self.tunnel.start()
398
-
395
+ self.establish_ssh_tunnel(local_port=port, remote_port=5432)
399
396
  try:
400
- # Use localhost and self.tunnel.local_bind_port to connect
401
397
  conn = pglib.connect(
402
- host='localhost', # Connect to localhost
403
- port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel
398
+ host=host,
399
+ port=port,
404
400
  database=dbname,
405
- user=user or self.username,
406
- password=password
401
+ user=user,
402
+ password=password,
407
403
  )
408
-
409
404
  return conn
410
405
  except Exception as e:
411
- self.tunnel.stop()
412
- raise ExecUtilException("Could not create db tunnel. {}".format(e))
406
+ raise Exception(f"Could not connect to the database. Error: {e}")
@@ -4,7 +4,7 @@ from __future__ import division
4
4
  from __future__ import print_function
5
5
 
6
6
  import os
7
- import port_for
7
+
8
8
  import sys
9
9
 
10
10
  from contextlib import contextmanager
@@ -13,6 +13,7 @@ import re
13
13
 
14
14
  from six import iteritems
15
15
 
16
+ from .helpers.port_manager import PortManager
16
17
  from .exceptions import ExecUtilException
17
18
  from .config import testgres_config as tconf
18
19
 
@@ -37,8 +38,8 @@ def reserve_port():
37
38
  """
38
39
  Generate a new port and add it to 'bound_ports'.
39
40
  """
40
-
41
- port = port_for.select_random(exclude_ports=bound_ports)
41
+ port_mng = PortManager()
42
+ port = port_mng.find_free_port(exclude_ports=bound_ports)
42
43
  bound_ports.add(port)
43
44
 
44
45
  return port
@@ -80,7 +81,8 @@ def execute_utility(args, logfile=None, verbose=False):
80
81
  lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n']
81
82
  tconf.os_ops.write(filename=logfile, data=lines)
82
83
  except IOError:
83
- raise ExecUtilException("Problem with writing to logfile `{}` during run command `{}`".format(logfile, args))
84
+ raise ExecUtilException(
85
+ "Problem with writing to logfile `{}` during run command `{}`".format(logfile, args))
84
86
  if verbose:
85
87
  return exit_status, out, error
86
88
  else:
@@ -179,6 +181,9 @@ def get_pg_version():
179
181
  _params = [get_bin_path('postgres'), '--version']
180
182
  raw_ver = tconf.os_ops.exec_command(_params, encoding='utf-8')
181
183
 
184
+ # Remove "(Homebrew)" if present
185
+ raw_ver = raw_ver.replace('(Homebrew)', '').strip()
186
+
182
187
  # cook version of PostgreSQL
183
188
  version = raw_ver.strip().split(' ')[-1] \
184
189
  .partition('devel')[0] \
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: testgres
3
- Version: 1.9.1
3
+ Version: 1.9.3
4
4
  Summary: Testing utility for PostgreSQL and its extensions
5
5
  Home-page: https://github.com/postgrespro/testgres
6
- Author: Ildar Musin
7
- Author-email: zildermann@gmail.com
6
+ Author: Postgres Professional
7
+ Author-email: testgres@postgrespro.ru
8
8
  License: PostgreSQL
9
9
  Keywords: test,testing,postgresql
10
10
  Platform: UNKNOWN
@@ -24,6 +24,8 @@ testgres.egg-info/SOURCES.txt
24
24
  testgres.egg-info/dependency_links.txt
25
25
  testgres.egg-info/requires.txt
26
26
  testgres.egg-info/top_level.txt
27
+ testgres/helpers/__init__.py
28
+ testgres/helpers/port_manager.py
27
29
  testgres/operations/__init__.py
28
30
  testgres/operations/local_ops.py
29
31
  testgres/operations/os_ops.py
@@ -1,7 +1,5 @@
1
- fabric
2
1
  packaging
3
2
  pg8000
4
3
  port-for>=0.4
5
4
  psutil
6
5
  six>=1.9.0
7
- sshtunnel
@@ -11,10 +11,9 @@ class TestRemoteOperations:
11
11
 
12
12
  @pytest.fixture(scope="function", autouse=True)
13
13
  def setup(self):
14
- conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3',
15
- username='dev',
16
- ssh_key=os.getenv(
17
- 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519')
14
+ conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1',
15
+ username=os.getenv('USER'),
16
+ ssh_key=os.getenv('RDBMS_TESTPOOL_SSHKEY'))
18
17
  self.operations = RemoteOperations(conn_params)
19
18
 
20
19
  def test_exec_command_success(self):
@@ -41,7 +40,7 @@ class TestRemoteOperations:
41
40
  """
42
41
  Test is_executable for an existing executable.
43
42
  """
44
- cmd = "postgres"
43
+ cmd = os.getenv('PG_CONFIG')
45
44
  response = self.operations.is_executable(cmd)
46
45
 
47
46
  assert response is True
@@ -74,6 +74,24 @@ def util_exists(util):
74
74
  return True
75
75
 
76
76
 
77
+ def rm_carriage_returns(out):
78
+ """
79
+ In Windows we have additional '\r' symbols in output.
80
+ Let's get rid of them.
81
+ """
82
+ if os.name == 'nt':
83
+ if isinstance(out, (int, float, complex)):
84
+ return out
85
+ elif isinstance(out, tuple):
86
+ return tuple(rm_carriage_returns(item) for item in out)
87
+ elif isinstance(out, bytes):
88
+ return out.replace(b'\r', b'')
89
+ else:
90
+ return out.replace('\r', '')
91
+ else:
92
+ return out
93
+
94
+
77
95
  @contextmanager
78
96
  def removing(f):
79
97
  try:
@@ -123,7 +141,7 @@ class TestgresTests(unittest.TestCase):
123
141
  node.cleanup()
124
142
  node.init().start().execute('select 1')
125
143
 
126
- @unittest.skipUnless(util_exists('pg_resetwal'), 'might be missing')
144
+ @unittest.skipUnless(util_exists('pg_resetwal.exe' if os.name == 'nt' else 'pg_resetwal'), 'pgbench might be missing')
127
145
  @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+')
128
146
  def test_init_unique_system_id(self):
129
147
  # this function exists in PostgreSQL 9.6+
@@ -254,34 +272,34 @@ class TestgresTests(unittest.TestCase):
254
272
 
255
273
  # check returned values (1 arg)
256
274
  res = node.psql('select 1')
257
- self.assertEqual(res, (0, b'1\n', b''))
275
+ self.assertEqual(rm_carriage_returns(res), (0, b'1\n', b''))
258
276
 
259
277
  # check returned values (2 args)
260
278
  res = node.psql('postgres', 'select 2')
261
- self.assertEqual(res, (0, b'2\n', b''))
279
+ self.assertEqual(rm_carriage_returns(res), (0, b'2\n', b''))
262
280
 
263
281
  # check returned values (named)
264
282
  res = node.psql(query='select 3', dbname='postgres')
265
- self.assertEqual(res, (0, b'3\n', b''))
283
+ self.assertEqual(rm_carriage_returns(res), (0, b'3\n', b''))
266
284
 
267
285
  # check returned values (1 arg)
268
286
  res = node.safe_psql('select 4')
269
- self.assertEqual(res, b'4\n')
287
+ self.assertEqual(rm_carriage_returns(res), b'4\n')
270
288
 
271
289
  # check returned values (2 args)
272
290
  res = node.safe_psql('postgres', 'select 5')
273
- self.assertEqual(res, b'5\n')
291
+ self.assertEqual(rm_carriage_returns(res), b'5\n')
274
292
 
275
293
  # check returned values (named)
276
294
  res = node.safe_psql(query='select 6', dbname='postgres')
277
- self.assertEqual(res, b'6\n')
295
+ self.assertEqual(rm_carriage_returns(res), b'6\n')
278
296
 
279
297
  # check feeding input
280
298
  node.safe_psql('create table horns (w int)')
281
299
  node.safe_psql('copy horns from stdin (format csv)',
282
300
  input=b"1\n2\n3\n\\.\n")
283
301
  _sum = node.safe_psql('select sum(w) from horns')
284
- self.assertEqual(_sum, b'6\n')
302
+ self.assertEqual(rm_carriage_returns(_sum), b'6\n')
285
303
 
286
304
  # check psql's default args, fails
287
305
  with self.assertRaises(QueryException):
@@ -455,7 +473,7 @@ class TestgresTests(unittest.TestCase):
455
473
  master.safe_psql(
456
474
  'insert into abc select generate_series(1, 1000000)')
457
475
  res = standby1.safe_psql('select count(*) from abc')
458
- self.assertEqual(res, b'1000000\n')
476
+ self.assertEqual(rm_carriage_returns(res), b'1000000\n')
459
477
 
460
478
  @unittest.skipUnless(pg_version_ge('10'), 'requires 10+')
461
479
  def test_logical_replication(self):
@@ -589,7 +607,7 @@ class TestgresTests(unittest.TestCase):
589
607
  # make standby becomes writable master
590
608
  replica.safe_psql('insert into abc values (1)')
591
609
  res = replica.safe_psql('select * from abc')
592
- self.assertEqual(res, b'1\n')
610
+ self.assertEqual(rm_carriage_returns(res), b'1\n')
593
611
 
594
612
  def test_dump(self):
595
613
  query_create = 'create table test as select generate_series(1, 2) as val'
@@ -614,6 +632,7 @@ class TestgresTests(unittest.TestCase):
614
632
  with get_new_node().init().start() as node:
615
633
  node.psql('create role test_user login')
616
634
  value = node.safe_psql('select 1', username='test_user')
635
+ value = rm_carriage_returns(value)
617
636
  self.assertEqual(value, b'1\n')
618
637
 
619
638
  def test_poll_query_until(self):
@@ -728,7 +747,7 @@ class TestgresTests(unittest.TestCase):
728
747
  master.restart()
729
748
  self.assertTrue(master._logger.is_alive())
730
749
 
731
- @unittest.skipUnless(util_exists('pgbench'), 'might be missing')
750
+ @unittest.skipUnless(util_exists('pgbench.exe' if os.name == 'nt' else 'pgbench'), 'pgbench might be missing')
732
751
  def test_pgbench(self):
733
752
  with get_new_node().init().start() as node:
734
753
 
@@ -744,6 +763,8 @@ class TestgresTests(unittest.TestCase):
744
763
  out, _ = proc.communicate()
745
764
  out = out.decode('utf-8')
746
765
 
766
+ proc.stdout.close()
767
+
747
768
  self.assertTrue('tps' in out)
748
769
 
749
770
  def test_pg_config(self):
@@ -977,7 +998,9 @@ class TestgresTests(unittest.TestCase):
977
998
 
978
999
  def test_child_process_dies(self):
979
1000
  # test for FileNotFound exception during child_processes() function
980
- with subprocess.Popen(["sleep", "60"]) as process:
1001
+ cmd = ["timeout", "60"] if os.name == 'nt' else ["sleep", "60"]
1002
+
1003
+ with subprocess.Popen(cmd, shell=True) as process: # shell=True might be needed on Windows
981
1004
  self.assertEqual(process.poll(), None)
982
1005
  # collect list of processes currently running
983
1006
  children = psutil.Process(os.getpid()).children()
@@ -52,10 +52,9 @@ from testgres import bound_ports
52
52
  from testgres.utils import PgVer
53
53
  from testgres.node import ProcessProxy, ConnectionParams
54
54
 
55
- conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3',
56
- username='dev',
57
- ssh_key=os.getenv(
58
- 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519')
55
+ conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1',
56
+ username=os.getenv('USER'),
57
+ ssh_key=os.getenv('RDBMS_TESTPOOL_SSHKEY'))
59
58
  os_ops = RemoteOperations(conn_params)
60
59
  testgres_config.set_os_ops(os_ops=os_ops)
61
60
 
@@ -735,9 +734,10 @@ class TestgresRemoteTests(unittest.TestCase):
735
734
  options=['-q']).pgbench_run(time=2)
736
735
 
737
736
  # run TPC-B benchmark
738
- out = node.pgbench(stdout=subprocess.PIPE,
739
- stderr=subprocess.STDOUT,
740
- options=['-T3'])
737
+ proc = node.pgbench(stdout=subprocess.PIPE,
738
+ stderr=subprocess.STDOUT,
739
+ options=['-T3'])
740
+ out = proc.communicate()[0]
741
741
  self.assertTrue(b'tps = ' in out)
742
742
 
743
743
  def test_pg_config(self):
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes