pyinfra 3.2__py2.py3-none-any.whl → 3.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. pyinfra/api/arguments_typed.py +4 -5
  2. pyinfra/api/command.py +22 -3
  3. pyinfra/api/config.py +5 -2
  4. pyinfra/api/facts.py +3 -0
  5. pyinfra/api/host.py +10 -4
  6. pyinfra/api/operation.py +2 -1
  7. pyinfra/api/state.py +1 -1
  8. pyinfra/connectors/base.py +34 -8
  9. pyinfra/connectors/chroot.py +7 -2
  10. pyinfra/connectors/docker.py +7 -2
  11. pyinfra/connectors/dockerssh.py +7 -2
  12. pyinfra/connectors/local.py +7 -2
  13. pyinfra/connectors/ssh.py +9 -2
  14. pyinfra/connectors/sshuserclient/client.py +16 -0
  15. pyinfra/connectors/sshuserclient/config.py +2 -0
  16. pyinfra/connectors/terraform.py +1 -1
  17. pyinfra/connectors/util.py +13 -9
  18. pyinfra/context.py +9 -2
  19. pyinfra/facts/apk.py +5 -0
  20. pyinfra/facts/apt.py +9 -1
  21. pyinfra/facts/brew.py +13 -0
  22. pyinfra/facts/bsdinit.py +3 -0
  23. pyinfra/facts/cargo.py +5 -0
  24. pyinfra/facts/choco.py +6 -0
  25. pyinfra/facts/crontab.py +6 -1
  26. pyinfra/facts/deb.py +10 -0
  27. pyinfra/facts/dnf.py +5 -0
  28. pyinfra/facts/docker.py +10 -0
  29. pyinfra/facts/efibootmgr.py +5 -0
  30. pyinfra/facts/files.py +19 -1
  31. pyinfra/facts/flatpak.py +7 -0
  32. pyinfra/facts/freebsd.py +75 -0
  33. pyinfra/facts/gem.py +5 -0
  34. pyinfra/facts/git.py +9 -0
  35. pyinfra/facts/gpg.py +7 -0
  36. pyinfra/facts/hardware.py +13 -0
  37. pyinfra/facts/iptables.py +9 -1
  38. pyinfra/facts/launchd.py +5 -0
  39. pyinfra/facts/lxd.py +5 -0
  40. pyinfra/facts/mysql.py +8 -0
  41. pyinfra/facts/npm.py +5 -0
  42. pyinfra/facts/openrc.py +8 -0
  43. pyinfra/facts/opkg.py +12 -0
  44. pyinfra/facts/pacman.py +9 -1
  45. pyinfra/facts/pip.py +5 -0
  46. pyinfra/facts/pipx.py +8 -0
  47. pyinfra/facts/pkg.py +4 -0
  48. pyinfra/facts/pkgin.py +5 -0
  49. pyinfra/facts/podman.py +7 -0
  50. pyinfra/facts/postgres.py +8 -2
  51. pyinfra/facts/rpm.py +11 -0
  52. pyinfra/facts/runit.py +7 -0
  53. pyinfra/facts/selinux.py +16 -0
  54. pyinfra/facts/server.py +49 -3
  55. pyinfra/facts/snap.py +7 -0
  56. pyinfra/facts/systemd.py +5 -0
  57. pyinfra/facts/sysvinit.py +4 -0
  58. pyinfra/facts/upstart.py +5 -0
  59. pyinfra/facts/util/__init__.py +4 -1
  60. pyinfra/facts/vzctl.py +5 -0
  61. pyinfra/facts/xbps.py +6 -1
  62. pyinfra/facts/yum.py +5 -0
  63. pyinfra/facts/zfs.py +19 -2
  64. pyinfra/facts/zypper.py +5 -0
  65. pyinfra/operations/apt.py +10 -3
  66. pyinfra/operations/docker.py +48 -44
  67. pyinfra/operations/files.py +47 -1
  68. pyinfra/operations/freebsd/__init__.py +12 -0
  69. pyinfra/operations/freebsd/freebsd_update.py +70 -0
  70. pyinfra/operations/freebsd/pkg.py +219 -0
  71. pyinfra/operations/freebsd/service.py +116 -0
  72. pyinfra/operations/freebsd/sysrc.py +92 -0
  73. pyinfra/operations/opkg.py +5 -5
  74. pyinfra/operations/postgres.py +99 -16
  75. pyinfra/operations/server.py +6 -4
  76. pyinfra/operations/util/docker.py +44 -22
  77. {pyinfra-3.2.dist-info → pyinfra-3.3.dist-info}/LICENSE.md +1 -1
  78. {pyinfra-3.2.dist-info → pyinfra-3.3.dist-info}/METADATA +25 -24
  79. {pyinfra-3.2.dist-info → pyinfra-3.3.dist-info}/RECORD +88 -82
  80. pyinfra_cli/exceptions.py +5 -0
  81. pyinfra_cli/log.py +3 -0
  82. pyinfra_cli/main.py +9 -8
  83. pyinfra_cli/prints.py +1 -1
  84. pyinfra_cli/virtualenv.py +1 -1
  85. tests/test_connectors/test_ssh.py +302 -182
  86. {pyinfra-3.2.dist-info → pyinfra-3.3.dist-info}/WHEEL +0 -0
  87. {pyinfra-3.2.dist-info → pyinfra-3.3.dist-info}/entry_points.txt +0 -0
  88. {pyinfra-3.2.dist-info → pyinfra-3.3.dist-info}/top_level.txt +0 -0
@@ -32,11 +32,6 @@ class PyinfraOperation(Generic[P], Protocol):
32
32
  def __call__(
33
33
  self,
34
34
  #
35
- # op args
36
- # needs to be first
37
- #
38
- *args: P.args,
39
- #
40
35
  # ConnectorArguments
41
36
  #
42
37
  # Auth
@@ -74,6 +69,10 @@ class PyinfraOperation(Generic[P], Protocol):
74
69
  _run_once: bool = False,
75
70
  _serial: bool = False,
76
71
  #
72
+ # op args
73
+ #
74
+ *args: P.args,
75
+ #
77
76
  # op kwargs
78
77
  #
79
78
  **kwargs: P.kwargs,
pyinfra/api/command.py CHANGED
@@ -6,9 +6,9 @@ from string import Formatter
6
6
  from typing import IO, TYPE_CHECKING, Callable, Union
7
7
 
8
8
  import gevent
9
- from typing_extensions import Unpack
9
+ from typing_extensions import Unpack, override
10
10
 
11
- from pyinfra.context import ctx_config, ctx_host
11
+ from pyinfra.context import LocalContextObject, ctx_config, ctx_host
12
12
 
13
13
  from .arguments import ConnectorArguments
14
14
 
@@ -58,6 +58,7 @@ class QuoteString:
58
58
  def __init__(self, obj: Union[str, "StringCommand"]):
59
59
  self.obj = obj
60
60
 
61
+ @override
61
62
  def __repr__(self) -> str:
62
63
  return f"QuoteString({self.obj})"
63
64
 
@@ -68,6 +69,7 @@ class PyinfraCommand:
68
69
  def __init__(self, **arguments: Unpack[ConnectorArguments]):
69
70
  self.connector_arguments = arguments
70
71
 
72
+ @override
71
73
  def __eq__(self, other) -> bool:
72
74
  if isinstance(other, self.__class__) and repr(self) == repr(other):
73
75
  return True
@@ -88,9 +90,11 @@ class StringCommand(PyinfraCommand):
88
90
  self.bits = bits
89
91
  self.separator = _separator
90
92
 
93
+ @override
91
94
  def __str__(self) -> str:
92
95
  return self.get_masked_value()
93
96
 
97
+ @override
94
98
  def __repr__(self) -> str:
95
99
  return f"StringCommand({self.get_masked_value()})"
96
100
 
@@ -131,6 +135,7 @@ class StringCommand(PyinfraCommand):
131
135
  ],
132
136
  )
133
137
 
138
+ @override
134
139
  def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
135
140
  connector_arguments.update(self.connector_arguments)
136
141
 
@@ -155,9 +160,11 @@ class FileUploadCommand(PyinfraCommand):
155
160
  self.dest = dest
156
161
  self.remote_temp_filename = remote_temp_filename
157
162
 
163
+ @override
158
164
  def __repr__(self):
159
165
  return "FileUploadCommand({0}, {1})".format(self.src, self.dest)
160
166
 
167
+ @override
161
168
  def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
162
169
  connector_arguments.update(self.connector_arguments)
163
170
 
@@ -184,9 +191,11 @@ class FileDownloadCommand(PyinfraCommand):
184
191
  self.dest = dest
185
192
  self.remote_temp_filename = remote_temp_filename
186
193
 
194
+ @override
187
195
  def __repr__(self):
188
196
  return "FileDownloadCommand({0}, {1})".format(self.src, self.dest)
189
197
 
198
+ @override
190
199
  def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
191
200
  connector_arguments.update(self.connector_arguments)
192
201
 
@@ -213,6 +222,7 @@ class FunctionCommand(PyinfraCommand):
213
222
  self.args = args
214
223
  self.kwargs = func_kwargs
215
224
 
225
+ @override
216
226
  def __repr__(self):
217
227
  return "FunctionCommand({0}, {1}, {2})".format(
218
228
  self.function.__name__,
@@ -220,12 +230,19 @@ class FunctionCommand(PyinfraCommand):
220
230
  self.kwargs,
221
231
  )
222
232
 
233
+ @override
223
234
  def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
224
235
  argspec = getfullargspec(self.function)
225
236
  if "state" in argspec.args and "host" in argspec.args:
226
237
  return self.function(state, host, *self.args, **self.kwargs)
227
238
 
228
- def execute_function():
239
+ # If we're already running inside a greenlet (ie a nested callback) just execute the func
240
+ # without any gevent.spawn which will break the local host object.
241
+ if isinstance(host, LocalContextObject):
242
+ self.function(*self.args, **self.kwargs)
243
+ return
244
+
245
+ def execute_function() -> None:
229
246
  with ctx_config.use(state.config.copy()):
230
247
  with ctx_host.use(host):
231
248
  self.function(*self.args, **self.kwargs)
@@ -241,9 +258,11 @@ class RsyncCommand(PyinfraCommand):
241
258
  self.dest = dest
242
259
  self.flags = flags
243
260
 
261
+ @override
244
262
  def __repr__(self):
245
263
  return "RsyncCommand({0}, {1}, {2})".format(self.src, self.dest, self.flags)
246
264
 
265
+ @override
247
266
  def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
248
267
  return host.rsync(
249
268
  self.src,
pyinfra/api/config.py CHANGED
@@ -2,6 +2,7 @@ try:
2
2
  import importlib_metadata
3
3
  except ImportError:
4
4
  import importlib.metadata as importlib_metadata # type: ignore[no-redef]
5
+
5
6
  from os import path
6
7
  from typing import Iterable, Optional, Set
7
8
 
@@ -9,6 +10,7 @@ from packaging.markers import Marker
9
10
  from packaging.requirements import Requirement
10
11
  from packaging.specifiers import SpecifierSet
11
12
  from packaging.version import Version
13
+ from typing_extensions import override
12
14
 
13
15
  from pyinfra import __version__, state
14
16
 
@@ -207,6 +209,7 @@ class Config(ConfigDefaults):
207
209
  for key, value in config.items():
208
210
  setattr(self, key, value)
209
211
 
212
+ @override
210
213
  def __setattr__(self, key, value):
211
214
  super().__setattr__(key, value)
212
215
 
@@ -221,10 +224,10 @@ class Config(ConfigDefaults):
221
224
  for key, value in config_state:
222
225
  setattr(self, key, value)
223
226
 
224
- def lock_current_state(self):
227
+ def lock_current_state(self) -> None:
225
228
  self._locked_config = self.get_current_state()
226
229
 
227
- def reset_locked_state(self):
230
+ def reset_locked_state(self) -> None:
228
231
  self.set_current_state(self._locked_config)
229
232
 
230
233
  def copy(self) -> "Config":
pyinfra/api/facts.py CHANGED
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Optional, Ty
19
19
  import click
20
20
  import gevent
21
21
  from paramiko import SSHException
22
+ from typing_extensions import override
22
23
 
23
24
  from pyinfra import logger
24
25
  from pyinfra.api import StringCommand
@@ -61,6 +62,7 @@ class FactBase(Generic[T]):
61
62
  def requires_command(self, *args, **kwargs) -> str | None:
62
63
  return None
63
64
 
65
+ @override
64
66
  def __init_subclass__(cls) -> None:
65
67
  super().__init_subclass__()
66
68
  module_name = cls.__module__.replace("pyinfra.facts.", "")
@@ -97,6 +99,7 @@ class ShortFactBase(Generic[T]):
97
99
  name: str
98
100
  fact: Type[FactBase]
99
101
 
102
+ @override
100
103
  def __init_subclass__(cls) -> None:
101
104
  super().__init_subclass__()
102
105
  module_name = cls.__module__.replace("pyinfra.facts.", "")
pyinfra/api/host.py CHANGED
@@ -17,7 +17,7 @@ from typing import (
17
17
  from uuid import uuid4
18
18
 
19
19
  import click
20
- from typing_extensions import Unpack
20
+ from typing_extensions import Unpack, override
21
21
 
22
22
  from pyinfra import logger
23
23
  from pyinfra.connectors.base import BaseConnector
@@ -75,9 +75,11 @@ class HostData:
75
75
 
76
76
  raise AttributeError(f"Host `{self.host}` has no data `{key}`")
77
77
 
78
+ @override
78
79
  def __setattr__(self, key: str, value: Any):
79
80
  self.override_datas[key] = value
80
81
 
82
+ @override
81
83
  def __str__(self):
82
84
  return str(self.datas)
83
85
 
@@ -147,8 +149,10 @@ class Host:
147
149
  name: str,
148
150
  inventory: "Inventory",
149
151
  groups,
150
- connector_cls=get_execution_connector("ssh"),
152
+ connector_cls=None,
151
153
  ):
154
+ if connector_cls is None:
155
+ connector_cls = get_execution_connector("ssh")
152
156
  self.inventory = inventory
153
157
  self.groups = groups
154
158
  self.connector_cls = connector_cls
@@ -181,9 +185,11 @@ class Host:
181
185
  padding_diff = longest_name_len - len(self.name)
182
186
  self.print_prefix_padding = "".join(" " for _ in range(0, padding_diff))
183
187
 
188
+ @override
184
189
  def __str__(self):
185
190
  return "{0}".format(self.name)
186
191
 
192
+ @override
187
193
  def __repr__(self):
188
194
  return "Host({0})".format(self.name)
189
195
 
@@ -357,7 +363,7 @@ class Host:
357
363
  # Connector proxy
358
364
  #
359
365
 
360
- def _check_state(self):
366
+ def _check_state(self) -> None:
361
367
  if not self.state:
362
368
  raise TypeError("Cannot call this function with no state!")
363
369
 
@@ -399,7 +405,7 @@ class Host:
399
405
  self.state.trigger_callbacks("host_connect", self)
400
406
  self.connected = True
401
407
 
402
- def disconnect(self):
408
+ def disconnect(self) -> None:
403
409
  """
404
410
  Disconnect from the host using its configured connector.
405
411
  """
pyinfra/api/operation.py CHANGED
@@ -13,7 +13,7 @@ from io import StringIO
13
13
  from types import FunctionType
14
14
  from typing import TYPE_CHECKING, Any, Callable, Generator, Iterator, Optional, cast
15
15
 
16
- from typing_extensions import ParamSpec
16
+ from typing_extensions import ParamSpec, override
17
17
 
18
18
  import pyinfra
19
19
  from pyinfra import context, logger
@@ -52,6 +52,7 @@ class OperationMeta:
52
52
  self._hash = hash
53
53
  self._maybe_is_change = is_change
54
54
 
55
+ @override
55
56
  def __repr__(self) -> str:
56
57
  """
57
58
  Return Operation object as a string.
pyinfra/api/state.py CHANGED
@@ -122,7 +122,7 @@ class StateHostMeta:
122
122
  ops_no_change = 0
123
123
  op_hashes: set[str]
124
124
 
125
- def __init__(self):
125
+ def __init__(self) -> None:
126
126
  self.op_hashes = set()
127
127
 
128
128
 
@@ -86,14 +86,13 @@ class BaseConnector(abc.ABC):
86
86
  @abc.abstractmethod
87
87
  def make_names_data(name: str) -> Iterator[tuple[str, dict, list[str]]]:
88
88
  """
89
- Generates hosts/data/groups information for inventory. This allows a
90
- single connector reference to generate multiple target hosts.
89
+ Generate inventory targets. This is a staticmethod because each yield will become a new host
90
+ object with a new (ie not this) instance of the connector.
91
91
  """
92
- ...
93
92
 
94
93
  def connect(self) -> None:
95
94
  """
96
- Connect this connector instance.
95
+ Connect this connector instance. Should raise ConnectError exceptions to indicate failure.
97
96
  """
98
97
 
99
98
  def disconnect(self) -> None:
@@ -108,7 +107,20 @@ class BaseConnector(abc.ABC):
108
107
  print_output: bool,
109
108
  print_input: bool,
110
109
  **arguments: Unpack["ConnectorArguments"],
111
- ) -> tuple[bool, "CommandOutput"]: ...
110
+ ) -> tuple[bool, "CommandOutput"]:
111
+ """
112
+ Execute a command.
113
+
114
+ Args:
115
+ command (StringCommand): actual command to execute
116
+ print_output (bool): whether to print command output
117
+ print_input (bool): whether to print command input
118
+ arguments (ConnectorArguments): connector global arguments
119
+
120
+ Returns:
121
+ tuple: (bool, CommandOutput)
122
+ Bool indicating success and CommandOutput with stdout/stderr lines.
123
+ """
112
124
 
113
125
  @abc.abstractmethod
114
126
  def put_file(
@@ -119,7 +131,14 @@ class BaseConnector(abc.ABC):
119
131
  print_output: bool = False,
120
132
  print_input: bool = False,
121
133
  **arguments: Unpack["ConnectorArguments"],
122
- ) -> bool: ...
134
+ ) -> bool:
135
+ """
136
+ Upload a local file or IO object by copying it to a temporary directory
137
+ and then writing it to the upload location.
138
+
139
+ Returns:
140
+ bool: indicating success or failure.
141
+ """
123
142
 
124
143
  @abc.abstractmethod
125
144
  def get_file(
@@ -130,9 +149,16 @@ class BaseConnector(abc.ABC):
130
149
  print_output: bool = False,
131
150
  print_input: bool = False,
132
151
  **arguments: Unpack["ConnectorArguments"],
133
- ) -> bool: ...
152
+ ) -> bool:
153
+ """
154
+ Download a remote file by copying it to a temporary location and then writing
155
+ it to our filename or IO object.
156
+
157
+ Returns:
158
+ bool: indicating success or failure.
159
+ """
134
160
 
135
- def check_can_rsync(self):
161
+ def check_can_rsync(self) -> None:
136
162
  raise NotImplementedError("This connector does not support rsync")
137
163
 
138
164
  def rsync(
@@ -3,7 +3,7 @@ from tempfile import mkstemp
3
3
  from typing import TYPE_CHECKING, Optional
4
4
 
5
5
  import click
6
- from typing_extensions import Unpack
6
+ from typing_extensions import Unpack, override
7
7
 
8
8
  from pyinfra import local, logger
9
9
  from pyinfra.api import QuoteString, StringCommand
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
22
22
 
23
23
 
24
24
  @memoize
25
- def show_warning():
25
+ def show_warning() -> None:
26
26
  logger.warning("The @chroot connector is in beta!")
27
27
 
28
28
 
@@ -39,6 +39,7 @@ class ChrootConnector(BaseConnector):
39
39
  super().__init__(state, host)
40
40
  self.local = LocalConnector(state, host)
41
41
 
42
+ @override
42
43
  @staticmethod
43
44
  def make_names_data(name: Optional[str] = None):
44
45
  if not name:
@@ -50,6 +51,7 @@ class ChrootConnector(BaseConnector):
50
51
  "chroot_directory": "/{0}".format(name.lstrip("/")),
51
52
  }, ["@chroot"]
52
53
 
54
+ @override
53
55
  def connect(self) -> None:
54
56
  self.local.connect()
55
57
 
@@ -66,6 +68,7 @@ class ChrootConnector(BaseConnector):
66
68
 
67
69
  self.host.connector_data["chroot_directory"] = chroot_directory
68
70
 
71
+ @override
69
72
  def run_shell_command(
70
73
  self,
71
74
  command,
@@ -97,6 +100,7 @@ class ChrootConnector(BaseConnector):
97
100
  **local_arguments,
98
101
  )
99
102
 
103
+ @override
100
104
  def put_file(
101
105
  self,
102
106
  filename_or_io,
@@ -148,6 +152,7 @@ class ChrootConnector(BaseConnector):
148
152
 
149
153
  return status
150
154
 
155
+ @override
151
156
  def get_file(
152
157
  self,
153
158
  remote_filename,
@@ -6,7 +6,7 @@ from tempfile import mkstemp
6
6
  from typing import TYPE_CHECKING
7
7
 
8
8
  import click
9
- from typing_extensions import TypedDict, Unpack
9
+ from typing_extensions import TypedDict, Unpack, override
10
10
 
11
11
  from pyinfra import local, logger
12
12
  from pyinfra.api import QuoteString, StringCommand
@@ -115,6 +115,7 @@ class DockerConnector(BaseConnector):
115
115
  ["@docker"],
116
116
  )
117
117
 
118
+ @override
118
119
  def connect(self) -> None:
119
120
  self.local.connect()
120
121
 
@@ -127,7 +128,8 @@ class DockerConnector(BaseConnector):
127
128
  except PyinfraError:
128
129
  self.container_id = _start_docker_image(docker_identifier)
129
130
 
130
- def disconnect(self):
131
+ @override
132
+ def disconnect(self) -> None:
131
133
  container_id = self.container_id
132
134
 
133
135
  if self.no_stop:
@@ -156,6 +158,7 @@ class DockerConnector(BaseConnector):
156
158
  ),
157
159
  )
158
160
 
161
+ @override
159
162
  def run_shell_command(
160
163
  self,
161
164
  command: StringCommand,
@@ -188,6 +191,7 @@ class DockerConnector(BaseConnector):
188
191
  **local_arguments,
189
192
  )
190
193
 
194
+ @override
191
195
  def put_file(
192
196
  self,
193
197
  filename_or_io,
@@ -245,6 +249,7 @@ class DockerConnector(BaseConnector):
245
249
 
246
250
  return status
247
251
 
252
+ @override
248
253
  def get_file(
249
254
  self,
250
255
  remote_filename,
@@ -3,7 +3,7 @@ from tempfile import mkstemp
3
3
  from typing import TYPE_CHECKING
4
4
 
5
5
  import click
6
- from typing_extensions import Unpack
6
+ from typing_extensions import Unpack, override
7
7
 
8
8
  from pyinfra import logger
9
9
  from pyinfra.api import QuoteString, StringCommand
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
22
22
 
23
23
 
24
24
  @memoize
25
- def show_warning():
25
+ def show_warning() -> None:
26
26
  logger.warning("The @dockerssh connector is in beta!")
27
27
 
28
28
 
@@ -68,6 +68,7 @@ class DockerSSHConnector(BaseConnector):
68
68
  ["@dockerssh"],
69
69
  )
70
70
 
71
+ @override
71
72
  def connect(self) -> None:
72
73
  self.ssh.connect()
73
74
 
@@ -97,6 +98,7 @@ class DockerSSHConnector(BaseConnector):
97
98
 
98
99
  self.host.host_data["docker_container_id"] = container_id
99
100
 
101
+ @override
100
102
  def disconnect(self) -> None:
101
103
  container_id = self.host.host_data["docker_container_id"][:12]
102
104
 
@@ -118,6 +120,7 @@ class DockerSSHConnector(BaseConnector):
118
120
  ),
119
121
  )
120
122
 
123
+ @override
121
124
  def run_shell_command(
122
125
  self,
123
126
  command,
@@ -150,6 +153,7 @@ class DockerSSHConnector(BaseConnector):
150
153
  **local_arguments,
151
154
  )
152
155
 
156
+ @override
153
157
  def put_file(
154
158
  self,
155
159
  filename_or_io,
@@ -221,6 +225,7 @@ class DockerSSHConnector(BaseConnector):
221
225
 
222
226
  return status
223
227
 
228
+ @override
224
229
  def get_file(
225
230
  self,
226
231
  remote_filename,
@@ -4,7 +4,7 @@ from tempfile import mkstemp
4
4
  from typing import TYPE_CHECKING, Tuple
5
5
 
6
6
  import click
7
- from typing_extensions import Unpack
7
+ from typing_extensions import Unpack, override
8
8
 
9
9
  from pyinfra import logger
10
10
  from pyinfra.api.command import QuoteString, StringCommand
@@ -45,6 +45,7 @@ class LocalConnector(BaseConnector):
45
45
 
46
46
  yield "@local", {}, ["@local"]
47
47
 
48
+ @override
48
49
  def run_shell_command(
49
50
  self,
50
51
  command: StringCommand,
@@ -101,6 +102,7 @@ class LocalConnector(BaseConnector):
101
102
 
102
103
  return status, combined_output
103
104
 
105
+ @override
104
106
  def put_file(
105
107
  self,
106
108
  filename_or_io,
@@ -152,6 +154,7 @@ class LocalConnector(BaseConnector):
152
154
 
153
155
  return status
154
156
 
157
+ @override
155
158
  def get_file(
156
159
  self,
157
160
  remote_filename,
@@ -206,10 +209,12 @@ class LocalConnector(BaseConnector):
206
209
 
207
210
  return True
208
211
 
209
- def check_can_rsync(self):
212
+ @override
213
+ def check_can_rsync(self) -> None:
210
214
  if not which("rsync"):
211
215
  raise NotImplementedError("The `rsync` binary is not available on this system.")
212
216
 
217
+ @override
213
218
  def rsync(
214
219
  self,
215
220
  src,
pyinfra/connectors/ssh.py CHANGED
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Optional, Tuple
9
9
 
10
10
  import click
11
11
  from paramiko import AuthenticationException, BadHostKeyException, SFTPClient, SSHException
12
- from typing_extensions import TypedDict, Unpack
12
+ from typing_extensions import TypedDict, Unpack, override
13
13
 
14
14
  from pyinfra import logger
15
15
  from pyinfra.api.command import QuoteString, StringCommand
@@ -191,6 +191,7 @@ class SSHConnector(BaseConnector):
191
191
 
192
192
  return kwargs
193
193
 
194
+ @override
194
195
  def connect(self) -> None:
195
196
  retries = self.data["ssh_connect_retries"]
196
197
 
@@ -264,9 +265,11 @@ class SSHConnector(BaseConnector):
264
265
  f"Host key for {e.hostname} does not match.",
265
266
  )
266
267
 
268
+ @override
267
269
  def disconnect(self) -> None:
268
270
  self.get_sftp_connection.cache.clear()
269
271
 
272
+ @override
270
273
  def run_shell_command(
271
274
  self,
272
275
  command: StringCommand,
@@ -368,6 +371,7 @@ class SSHConnector(BaseConnector):
368
371
  sftp = self.get_sftp_connection()
369
372
  sftp.getfo(remote_filename, file_io)
370
373
 
374
+ @override
371
375
  def get_file(
372
376
  self,
373
377
  remote_filename: str,
@@ -454,6 +458,7 @@ class SSHConnector(BaseConnector):
454
458
  if last_e is not None:
455
459
  raise last_e
456
460
 
461
+ @override
457
462
  def put_file(
458
463
  self,
459
464
  filename_or_io,
@@ -537,7 +542,8 @@ class SSHConnector(BaseConnector):
537
542
 
538
543
  return True
539
544
 
540
- def check_can_rsync(self):
545
+ @override
546
+ def check_can_rsync(self) -> None:
541
547
  if self.data["ssh_key_password"]:
542
548
  raise NotImplementedError(
543
549
  "Rsync does not currently work with SSH keys needing passwords."
@@ -549,6 +555,7 @@ class SSHConnector(BaseConnector):
549
555
  if not which("rsync"):
550
556
  raise NotImplementedError("The `rsync` binary is not available on this system.")
551
557
 
558
+ @override
552
559
  def rsync(
553
560
  self,
554
561
  src: str,
@@ -15,6 +15,7 @@ from paramiko import (
15
15
  )
16
16
  from paramiko.agent import AgentRequestHandler
17
17
  from paramiko.hostkeys import HostKeyEntry
18
+ from typing_extensions import override
18
19
 
19
20
  from pyinfra import logger
20
21
  from pyinfra.api.util import memoize
@@ -25,6 +26,7 @@ HOST_KEYS_LOCK = BoundedSemaphore()
25
26
 
26
27
 
27
28
  class StrictPolicy(MissingHostKeyPolicy):
29
+ @override
28
30
  def missing_host_key(self, client, hostname, key):
29
31
  logger.error("No host key for {0} found in known_hosts".format(hostname))
30
32
  raise SSHException(
@@ -55,6 +57,7 @@ def append_hostkey(client, hostname, key):
55
57
 
56
58
 
57
59
  class AcceptNewPolicy(MissingHostKeyPolicy):
60
+ @override
58
61
  def missing_host_key(self, client, hostname, key):
59
62
  logger.warning(
60
63
  (
@@ -68,6 +71,7 @@ class AcceptNewPolicy(MissingHostKeyPolicy):
68
71
 
69
72
 
70
73
  class AskPolicy(MissingHostKeyPolicy):
74
+ @override
71
75
  def missing_host_key(self, client, hostname, key):
72
76
  should_continue = input(
73
77
  "No host key for {0} found in known_hosts, do you want to continue [y/n] ".format(
@@ -84,6 +88,7 @@ class AskPolicy(MissingHostKeyPolicy):
84
88
 
85
89
 
86
90
  class WarningPolicy(MissingHostKeyPolicy):
91
+ @override
87
92
  def missing_host_key(self, client, hostname, key):
88
93
  logger.warning("No host key for {0} found in known_hosts".format(hostname))
89
94
 
@@ -136,6 +141,7 @@ class SSHClient(ParamikoClient):
136
141
  original idea at http://bitprophet.org/blog/2012/11/05/gateway-solutions/.
137
142
  """
138
143
 
144
+ @override
139
145
  def connect( # type: ignore[override]
140
146
  self,
141
147
  hostname,
@@ -177,6 +183,13 @@ class SSHClient(ParamikoClient):
177
183
  if _pyinfra_ssh_forward_agent is not None:
178
184
  forward_agent = _pyinfra_ssh_forward_agent
179
185
 
186
+ keep_alive = config.get("keep_alive")
187
+
188
+ if keep_alive:
189
+ transport = self.get_transport()
190
+ assert transport is not None, "No transport"
191
+ transport.set_keepalive(keep_alive)
192
+
180
193
  if forward_agent:
181
194
  transport = self.get_transport()
182
195
  assert transport is not None, "No transport"
@@ -234,6 +247,9 @@ class SSHClient(ParamikoClient):
234
247
  if "port" in host_config:
235
248
  cfg["port"] = int(host_config["port"])
236
249
 
250
+ if "serveraliveinterval" in host_config:
251
+ cfg["keep_alive"] = int(host_config["serveraliveinterval"])
252
+
237
253
  if "proxycommand" in host_config:
238
254
  cfg["sock"] = ProxyCommand(host_config["proxycommand"])
239
255