pyinfra 3.5.1__py3-none-any.whl → 3.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. pyinfra/api/__init__.py +1 -0
  2. pyinfra/api/arguments.py +9 -2
  3. pyinfra/api/arguments_typed.py +18 -23
  4. pyinfra/api/command.py +9 -3
  5. pyinfra/api/deploy.py +1 -1
  6. pyinfra/api/exceptions.py +12 -0
  7. pyinfra/api/facts.py +20 -4
  8. pyinfra/api/host.py +3 -0
  9. pyinfra/api/inventory.py +2 -2
  10. pyinfra/api/metadata.py +69 -0
  11. pyinfra/api/operation.py +9 -4
  12. pyinfra/api/operations.py +16 -14
  13. pyinfra/api/util.py +22 -5
  14. pyinfra/connectors/docker.py +25 -1
  15. pyinfra/connectors/ssh.py +57 -0
  16. pyinfra/connectors/sshuserclient/client.py +47 -28
  17. pyinfra/connectors/util.py +16 -9
  18. pyinfra/facts/crontab.py +10 -8
  19. pyinfra/facts/files.py +12 -3
  20. pyinfra/facts/flatpak.py +1 -1
  21. pyinfra/facts/npm.py +1 -1
  22. pyinfra/facts/server.py +18 -2
  23. pyinfra/operations/apk.py +2 -1
  24. pyinfra/operations/apt.py +15 -7
  25. pyinfra/operations/brew.py +1 -0
  26. pyinfra/operations/crontab.py +4 -1
  27. pyinfra/operations/dnf.py +4 -1
  28. pyinfra/operations/docker.py +70 -16
  29. pyinfra/operations/files.py +87 -12
  30. pyinfra/operations/flatpak.py +1 -0
  31. pyinfra/operations/gem.py +1 -0
  32. pyinfra/operations/git.py +1 -0
  33. pyinfra/operations/iptables.py +1 -0
  34. pyinfra/operations/lxd.py +1 -0
  35. pyinfra/operations/mysql.py +1 -0
  36. pyinfra/operations/opkg.py +2 -1
  37. pyinfra/operations/pacman.py +1 -0
  38. pyinfra/operations/pip.py +1 -0
  39. pyinfra/operations/pipx.py +1 -0
  40. pyinfra/operations/pkg.py +1 -0
  41. pyinfra/operations/pkgin.py +1 -0
  42. pyinfra/operations/postgres.py +7 -1
  43. pyinfra/operations/puppet.py +1 -0
  44. pyinfra/operations/python.py +1 -0
  45. pyinfra/operations/selinux.py +1 -0
  46. pyinfra/operations/server.py +1 -0
  47. pyinfra/operations/snap.py +2 -1
  48. pyinfra/operations/ssh.py +1 -0
  49. pyinfra/operations/systemd.py +1 -0
  50. pyinfra/operations/sysvinit.py +2 -1
  51. pyinfra/operations/util/docker.py +172 -8
  52. pyinfra/operations/util/packaging.py +2 -0
  53. pyinfra/operations/xbps.py +1 -0
  54. pyinfra/operations/yum.py +4 -1
  55. pyinfra/operations/zfs.py +1 -0
  56. pyinfra/operations/zypper.py +1 -0
  57. {pyinfra-3.5.1.dist-info → pyinfra-3.6.1.dist-info}/METADATA +2 -1
  58. {pyinfra-3.5.1.dist-info → pyinfra-3.6.1.dist-info}/RECORD +64 -63
  59. {pyinfra-3.5.1.dist-info → pyinfra-3.6.1.dist-info}/WHEEL +1 -1
  60. pyinfra_cli/cli.py +20 -4
  61. pyinfra_cli/inventory.py +26 -1
  62. pyinfra_cli/util.py +1 -1
  63. {pyinfra-3.5.1.dist-info → pyinfra-3.6.1.dist-info}/entry_points.txt +0 -0
  64. {pyinfra-3.5.1.dist-info → pyinfra-3.6.1.dist-info}/licenses/LICENSE.md +0 -0
pyinfra/api/__init__.py CHANGED
@@ -14,6 +14,7 @@ from .exceptions import (  # noqa: F401
     FactError,
     FactTypeError,
     FactValueError,
+    FactProcessError,
     InventoryError,
     OperationError,
     OperationTypeError,
pyinfra/api/arguments.py CHANGED
@@ -70,12 +70,15 @@ class ConnectorArguments(TypedDict, total=False):
     _success_exit_codes: Iterable[int]
     _timeout: int
     _get_pty: bool
-    _stdin: Union[str, Iterable[str]]
+    _stdin: Union[str, list[str], Iterable[str]]

     # Retry arguments
     _retries: int
     _retry_delay: Union[int, float]
-    _retry_until: Optional[Callable[[dict], bool]]
+    _retry_until: Callable[[dict], bool]
+
+    # Temp directory argument
+    _temp_dir: str


 def generate_env(config: "Config", value: dict) -> dict:
@@ -163,6 +166,10 @@ shell_argument_meta: dict[str, ArgumentMeta] = {
         "String or buffer to send to the stdin of any commands.",
         default=lambda _: None,
     ),
+    "_temp_dir": ArgumentMeta(
+        "Temporary directory on the remote host for file operations.",
+        default=lambda config: config.TEMP_DIR,
+    ),
 }

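The new _temp_dir global argument added above can be overridden per operation like any other underscore-prefixed argument. A minimal sketch (hypothetical file paths, assuming the standard files.put operation):

from pyinfra.operations import files

files.put(
    name="Upload config via a custom temporary directory",
    src="files/app.conf",          # hypothetical local file
    dest="/etc/app/app.conf",      # hypothetical remote path
    _temp_dir="/var/tmp/pyinfra",  # overrides config.TEMP_DIR for this operation only
)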
pyinfra/api/arguments_typed.py CHANGED
@@ -1,16 +1,6 @@
 from __future__ import annotations

-from typing import (
-    TYPE_CHECKING,
-    Callable,
-    Generator,
-    Generic,
-    Iterable,
-    List,
-    Mapping,
-    Optional,
-    Union,
-)
+from typing import TYPE_CHECKING, Callable, Generator, Generic, Iterable, List, Mapping, Union

 from typing_extensions import ParamSpec, Protocol

@@ -36,36 +26,41 @@ class PyinfraOperation(Generic[P], Protocol):
         #
         # Auth
         _sudo: bool = False,
-        _sudo_user: Optional[str] = None,
+        _sudo_user: None | str = None,
         _use_sudo_login: bool = False,
-        _sudo_password: Optional[str] = None,
+        _sudo_password: None | str = None,
         _preserve_sudo_env: bool = False,
-        _su_user: Optional[str] = None,
+        _su_user: None | str = None,
         _use_su_login: bool = False,
         _preserve_su_env: bool = False,
-        _su_shell: Optional[str] = None,
+        _su_shell: None | str = None,
         _doas: bool = False,
-        _doas_user: Optional[str] = None,
+        _doas_user: None | str = None,
         # Shell arguments
-        _shell_executable: Optional[str] = None,
-        _chdir: Optional[str] = None,
-        _env: Optional[Mapping[str, str]] = None,
+        _shell_executable: None | str = None,
+        _chdir: None | str = None,
+        _env: None | Mapping[str, str] = None,
         # Connector control
         _success_exit_codes: Iterable[int] = (0,),
-        _timeout: Optional[int] = None,
+        _timeout: None | int = None,
         _get_pty: bool = False,
-        _stdin: Union[None, str, list[str], tuple[str, ...]] = None,
+        _stdin: None | Union[str, list[str], Iterable[str]] = None,
+        # Retry arguments
+        _retries: None | int = None,
+        _retry_delay: None | Union[int, float] = None,
+        _retry_until: None | Callable[[dict], bool] = None,
+        _temp_dir: None | str = None,
         #
         # MetaArguments
         #
-        name: Optional[str] = None,
+        name: None | str = None,
         _ignore_errors: bool = False,
         _continue_on_error: bool = False,
         _if: Union[List[Callable[[], bool]], Callable[[], bool], None] = None,
         #
         # ExecutionArguments
         #
-        _parallel: Optional[int] = None,
+        _parallel: None | int = None,
         _run_once: bool = False,
         _serial: bool = False,
         #
pyinfra/api/command.py CHANGED
@@ -242,13 +242,19 @@ class FunctionCommand(PyinfraCommand):
             self.function(*self.args, **self.kwargs)
             return

-        def execute_function() -> None:
+        def execute_function() -> None | Exception:
             with ctx_config.use(state.config.copy()):
                 with ctx_host.use(host):
-                    self.function(*self.args, **self.kwargs)
+                    try:
+                        self.function(*self.args, **self.kwargs)
+                    except Exception as e:
+                        return e
+            return None

         greenlet = gevent.spawn(execute_function)
-        return greenlet.get()
+        exception = greenlet.get()
+        if exception is not None:
+            raise exception


 class RsyncCommand(PyinfraCommand):
pyinfra/api/deploy.py CHANGED
@@ -41,7 +41,7 @@ def add_deploy(state: "State", deploy_func: Callable[..., Any], *args, **kwargs)
             ).format(get_call_location()),
         )

-    hosts = kwargs.pop("host", state.inventory.iter_active_hosts())
+    hosts = kwargs.pop("host", state.inventory.get_active_hosts())
     if isinstance(hosts, Host):
         hosts = [hosts]

pyinfra/api/exceptions.py CHANGED
@@ -29,6 +29,12 @@ class FactValueError(FactError, ValueError):
     """


+class FactProcessError(FactError, RuntimeError):
+    """
+    Exception raised when the data gathered for a fact cannot be processed.
+    """
+
+
 class OperationError(PyinfraError):
     """
     Exception raised during fact gathering staging if an operation is unable to
@@ -48,6 +54,12 @@ class OperationValueError(OperationError, ValueError):
     """


+class NestedOperationError(OperationError):
+    """
+    Exception raised when a nested (immediately executed) operation fails.
+    """
+
+
 class DeployError(PyinfraError):
     """
     User exception for raising in deploys or sub deploys.
pyinfra/api/facts.py CHANGED
@@ -14,7 +14,7 @@ import inspect
 import re
 from inspect import getcallargs
 from socket import error as socket_error, timeout as timeout_error
-from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Optional, Type, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, Type, TypeVar, cast

 import click
 import gevent
@@ -24,6 +24,7 @@ from typing_extensions import override
 from pyinfra import logger
 from pyinfra.api import StringCommand
 from pyinfra.api.arguments import all_global_arguments, pop_global_arguments
+from pyinfra.api.exceptions import FactProcessError
 from pyinfra.api.util import (
     get_kwargs_str,
     log_error_or_warning,
@@ -86,7 +87,7 @@ class FactBase(Generic[T]):

         return cast(T, None)

-    def process(self, output: Iterable[str]) -> T:
+    def process(self, output: list[str]) -> T:
         # NOTE: TypeVar does not support a default, so we have to cast this str -> T
         return cast(T, "\n".join(output))

@@ -151,7 +152,7 @@ def get_facts(state, *args, **kwargs):
     with ctx_state.use(state):
         greenlet_to_host = {
             state.pool.spawn(get_host_fact, host, *args, **kwargs): host
-            for host in state.inventory.iter_active_hosts()
+            for host in state.inventory.get_active_hosts()
         }

         results = {}
@@ -269,7 +270,22 @@ def _get_fact(

     if status:
         if stdout_lines:
-            data = fact.process(stdout_lines)
+            try:
+                data = fact.process(stdout_lines)
+            except FactProcessError as e:
+                log_error_or_warning(
+                    host,
+                    global_kwargs["_ignore_errors"],
+                    description=("could not process fact: {0} {1}").format(
+                        name, get_kwargs_str(fact_kwargs)
+                    ),
+                    exception=e,
+                )
+
+                # Check we've not failed
+                if apply_failed_hosts and not global_kwargs["_ignore_errors"]:
+                    state.fail_hosts({host})
+
         elif stderr_lines:
             # If we have error output and that error is sudo or su stating the user
             # does not exist, do not fail but instead return the default fact value.
pyinfra/api/host.py CHANGED
@@ -328,6 +328,9 @@ class Host:

         return temp_directory

+    def get_temp_dir_config(self):
+        return self.state.config.TEMP_DIR or self.state.config.DEFAULT_TEMP_DIR
+
     def get_temp_filename(
         self,
         hash_key: Optional[str] = None,
pyinfra/api/inventory.py CHANGED
@@ -158,11 +158,11 @@ class Inventory:

         return iter(self.hosts.values())

-    def iter_active_hosts(self) -> Iterator["Host"]:
+    def get_active_hosts(self) -> list["Host"]:
         """
         Iterates over active inventory hosts.
         """
-        return iter(self.state.active_hosts)
+        return list(self.state.active_hosts)

     def len_active_hosts(self) -> int:
         """
pyinfra/api/metadata.py ADDED
@@ -0,0 +1,69 @@
+"""
+Support parsing pyinfra-metadata.toml
+
+Currently just parses plugins and their metadata.
+"""
+
+import tomllib
+from typing import Literal, get_args
+
+from pydantic import BaseModel, TypeAdapter, field_validator
+
+AllowedTagType = Literal[
+    "boot",
+    "containers",
+    "database",
+    "service-management",
+    "package-manager",
+    "python",
+    "ruby",
+    "javascript",
+    "configuration-management",
+    "security",
+    "storage",
+    "system",
+    "rust",
+    "version-control-system",
+]
+
+
+class Tag(BaseModel):
+    """Representation of a plugin tag."""
+
+    value: AllowedTagType
+
+    @field_validator("value", mode="before")
+    def _validate_value(cls, v) -> AllowedTagType:
+        allowed_tags = set(get_args(AllowedTagType))
+        if v not in allowed_tags:
+            raise ValueError(f"Invalid tag: {v}. Allowed: {allowed_tags}")
+        return v
+
+    @property
+    def title_case(self) -> str:
+        return " ".join([t.title() for t in self.value.split("-")])
+
+
+ALLOWED_TAGS = [Tag(value=tag) for tag in set(get_args(AllowedTagType))]
+
+
+class Plugin(BaseModel):
+    """Representation of a pyinfra plugin."""
+
+    name: str
+    # description: str  # FUTURE we should grab these from doc strings
+    path: str
+    type: Literal["operation", "fact", "connector", "deploy"]
+    tags: list[Tag]
+
+    @field_validator("tags", mode="before")
+    def _wrap_tags(cls, v):
+        return [Tag(value=tag) if not isinstance(tag, Tag) else tag for tag in v]
+
+
+def parse_plugins(metadata_text: str) -> list[Plugin]:
+    """Given the contents of a pyinfra-metadata.toml parse out the plugins."""
+    pyinfra_metadata = tomllib.loads(metadata_text).get("pyinfra", None)
+    if not pyinfra_metadata:
+        raise ValueError("Missing [pyinfra.plugins] section in pyinfra-metadata.toml")
+    return TypeAdapter(list[Plugin]).validate_python(pyinfra_metadata["plugins"].values())
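The new pyinfra/api/metadata.py module validates plugin entries read from a pyinfra-metadata.toml file. A hedged usage sketch follows; the TOML layout and values shown are assumptions inferred from the parser above, not a documented format:

from pyinfra.api.metadata import parse_plugins

metadata_text = """
[pyinfra.plugins.apt]
name = "apt"
path = "pyinfra/operations/apt.py"
type = "operation"
tags = ["package-manager"]
"""

# parse_plugins raises ValueError if the [pyinfra] section is missing
for plugin in parse_plugins(metadata_text):
    print(plugin.name, plugin.type, [tag.title_case for tag in plugin.tags])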
pyinfra/api/operation.py CHANGED
@@ -22,7 +22,7 @@ from pyinfra.context import ctx_host, ctx_state
 from .arguments import EXECUTION_KWARG_KEYS, AllArguments, pop_global_arguments
 from .arguments_typed import PyinfraOperation
 from .command import PyinfraCommand, StringCommand
-from .exceptions import OperationValueError, PyinfraError
+from .exceptions import NestedOperationError, OperationValueError, PyinfraError
 from .host import Host
 from .operations import run_host_op
 from .state import State, StateOperationHostData, StateOperationMeta, StateStage
@@ -221,7 +221,7 @@ def add_op(state: State, op_func, *args, **kwargs):
             ),
         )

-    hosts = kwargs.pop("host", state.inventory.iter_active_hosts())
+    hosts = kwargs.pop("host", state.inventory.get_active_hosts())
    if isinstance(hosts, Host):
        hosts = [hosts]

@@ -266,7 +266,9 @@ def _wrap_operation(func: Callable[P, Generator], _set_in_op: bool = True) -> Py
         state = context.state
         host = context.host

-        if state.current_stage < StateStage.Prepare or state.current_stage > StateStage.Execute:
+        if pyinfra.is_cli and (
+            state.current_stage < StateStage.Prepare or state.current_stage > StateStage.Execute
+        ):
             raise Exception("Cannot call operations outside of Prepare/Execute stages")

         if host.in_op:
@@ -470,8 +472,11 @@ def execute_immediately(state, host, op_hash):
     op_meta = state.get_op_meta(op_hash)
     op_data = state.get_op_data_for_host(host, op_hash)
     op_data.parent_op_hash = host.executing_op_hash
+
     log_operation_start(op_meta, op_types=["nested"], prefix="")
-    run_host_op(state, host, op_hash)
+
+    if run_host_op(state, host, op_hash) is False:
+        raise NestedOperationError(op_hash)


 def _get_arg_value(arg):
pyinfra/api/operations.py CHANGED
@@ -4,7 +4,7 @@ import time
 import traceback
 from itertools import product
 from socket import error as socket_error, timeout as timeout_error
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, cast

 import click
 import gevent
@@ -17,7 +17,7 @@ from pyinfra.progress import progress_spinner

 from .arguments import CONNECTOR_ARGUMENT_KEYS, ConnectorArguments
 from .command import FunctionCommand, PyinfraCommand, StringCommand
-from .exceptions import PyinfraError
+from .exceptions import NestedOperationError, PyinfraError
 from .util import (
     format_exception,
     log_error_or_warning,
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
 #


-def run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]:
+def run_host_op(state: "State", host: "Host", op_hash: str) -> bool:
     state.trigger_callbacks("operation_host_start", host, op_hash)

     if op_hash not in state.ops[host]:
@@ -59,7 +59,7 @@ def run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]:
         host.executing_op_hash = None


-def _run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]:
+def _run_host_op(state: "State", host: "Host", op_hash: str) -> bool:
     op_data = state.get_op_data_for_host(host, op_hash)
     global_arguments = op_data.global_arguments

@@ -104,6 +104,8 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]:
         if isinstance(command, FunctionCommand):
             try:
                 status = command.execute(state, host, connector_arguments)
+            except NestedOperationError:
+                host.log_styled("Error in nested operation", fg="red", log_func=logger.error)
             except Exception as e:
                 # Custom functions could do anything, so expect anything!
                 logger.warning(traceback.format_exc())
@@ -194,11 +196,11 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> Optional[bool]:
         host_results.ops += 1
         host_results.success_ops += 1

-        _status_log = "Success" if executed_commands > 0 else "No changes"
+        _status_text = "Success" if executed_commands > 0 else "No changes"
         if retry_attempt > 0:
-            _status_log = f"{_status_log} on retry {retry_attempt}"
+            _status_text = f"{_status_text} on retry {retry_attempt}"

-        _click_log_status = click.style(_status_log, "green")
+        _click_log_status = click.style(_status_text, "green" if executed_commands > 0 else "cyan")
         logger.info("{0}{1}".format(host.print_prefix, _click_log_status))

         state.trigger_callbacks("operation_host_success", host, op_hash, retry_attempt)
@@ -278,7 +280,7 @@ def _run_serial_ops(state: "State"):
     Run all ops for all servers, one server at a time.
     """

-    for host in list(state.inventory.iter_active_hosts()):
+    for host in list(state.inventory.get_active_hosts()):
         host_operations = product([host], state.get_op_order())
         with progress_spinner(host_operations) as progress:
             try:
@@ -296,7 +298,7 @@ def _run_no_wait_ops(state: "State"):
     Run all ops for all servers at once.
     """

-    hosts_operations = product(state.inventory.iter_active_hosts(), state.get_op_order())
+    hosts_operations = product(state.inventory.get_active_hosts(), state.get_op_order())
     with progress_spinner(hosts_operations) as progress:
         # Spawn greenlet for each host to run *all* ops
         if state.pool is None:
@@ -308,7 +310,7 @@ def _run_no_wait_ops(state: "State"):
                 host,
                 progress=progress,
             )
-            for host in state.inventory.iter_active_hosts()
+            for host in state.inventory.get_active_hosts()
         ]
         gevent.joinall(greenlets)

@@ -326,9 +328,9 @@ def _run_single_op(state: "State", op_hash: str):
     failed_hosts = set()

     if op_meta.global_arguments["_serial"]:
-        with progress_spinner(state.inventory.iter_active_hosts()) as progress:
+        with progress_spinner(state.inventory.get_active_hosts()) as progress:
             # For each host, run the op
-            for host in state.inventory.iter_active_hosts():
+            for host in state.inventory.get_active_hosts():
                 result = _run_host_op_with_context(state, host, op_hash)
                 progress(host)

@@ -337,12 +339,12 @@ def _run_single_op(state: "State", op_hash: str):

     else:
         # Start with the whole inventory in one batch
-        batches = [list(state.inventory.iter_active_hosts())]
+        batches = [list(state.inventory.get_active_hosts())]

         # If parallel set break up the inventory into a series of batches
         parallel = op_meta.global_arguments["_parallel"]
         if parallel:
-            hosts = list(state.inventory.iter_active_hosts())
+            hosts = list(state.inventory.get_active_hosts())
             batches = [hosts[i : i + parallel] for i in range(0, len(hosts), parallel)]

         for batch in batches:
pyinfra/api/util.py CHANGED
@@ -10,7 +10,7 @@ from socket import error as socket_error, timeout as timeout_error
 from typing import IO, TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

 import click
-from jinja2 import Environment, FileSystemLoader, StrictUndefined
+from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template
 from paramiko import SSHException
 from typeguard import TypeCheckError, check_type

@@ -26,7 +26,7 @@ if TYPE_CHECKING:
 BLOCKSIZE = 65536

 # Caches
-TEMPLATES: Dict[Any, Any] = {}
+TEMPLATES: Dict[str, Template] = {}
 FILE_SHAS: Dict[Any, Any] = {}

 PYINFRA_INSTALL_DIR = path.normpath(path.join(path.dirname(__file__), ".."))
@@ -139,7 +139,9 @@ def get_operation_order_from_stack(state: "State"):
     return line_numbers


-def get_template(filename_or_io: str | IO, jinja_env_kwargs: dict[str, Any] | None = None):
+def get_template(
+    filename_or_io: str | IO, jinja_env_kwargs: dict[str, Any] | None = None
+) -> Template:
     """
     Gets a jinja2 ``Template`` object for the input filename or string, with caching
     based on the filename of the template, or the SHA1 of the input string.
@@ -155,10 +157,11 @@ def get_template(filename_or_io: str | IO, jinja_env_kwargs: dict[str, Any] | No
     with file_data as file_io:
         template_string = file_io.read()

+    default_loader = FileSystemLoader(getcwd())
     template = Environment(
         undefined=StrictUndefined,
         keep_trailing_newline=True,
-        loader=FileSystemLoader(getcwd()),
+        loader=jinja_env_kwargs.pop("loader", default_loader),
         **jinja_env_kwargs,
     ).from_string(template_string)

@@ -219,7 +222,11 @@ def log_operation_start(


 def log_error_or_warning(
-    host: "Host", ignore_errors: bool, description: str = "", continue_on_error: bool = False
+    host: "Host",
+    ignore_errors: bool,
+    description: str = "",
+    continue_on_error: bool = False,
+    exception: Exception | None = None,
 ) -> None:
     log_func = logger.error
     log_color = "red"
@@ -234,6 +241,16 @@ def log_error_or_warning(
     if description:
         log_text = f"{log_text}: "

+    if exception:
+        exc = exception.__cause__ or exception
+        exc_text = "{0}: {1}".format(type(exc).__name__, exc)
+        log_func(
+            "{0}{1}".format(
+                host.print_prefix,
+                click.style(exc_text, log_color),
+            ),
+        )
+
     log_func(
         "{0}{1}{2}".format(
             host.print_prefix,
pyinfra/connectors/docker.py CHANGED
@@ -26,10 +26,14 @@ if TYPE_CHECKING:

 class ConnectorData(TypedDict):
     docker_identifier: str
+    docker_platform: str
+    docker_architecture: str


 connector_data_meta: dict[str, DataMeta] = {
     "docker_identifier": DataMeta("ID of container or image to start from"),
+    "docker_platform": DataMeta("Platform to use for Docker image (e.g., linux/amd64)"),
+    "docker_architecture": DataMeta("Architecture to use for Docker image (e.g., amd64, arm64)"),
 }


@@ -108,9 +112,29 @@ class DockerConnector(BaseConnector):
         return container_id, True

     def _start_docker_image(self, image_name):
+        docker_cmd_parts = [
+            self.docker_cmd,
+            "run",
+            "-d",
+        ]
+
+        if self.data.get("docker_platform"):
+            docker_cmd_parts.extend(["--platform", self.data["docker_platform"]])
+        if self.data.get("docker_architecture"):
+            docker_cmd_parts.extend(["--arch", self.data["docker_architecture"]])
+
+        docker_cmd_parts.extend(
+            [
+                image_name,
+                "tail",
+                "-f",
+                "/dev/null",
+            ]
+        )
+
         try:
             return local.shell(
-                f"{self.docker_cmd} run -d {image_name} tail -f /dev/null",
+                " ".join(docker_cmd_parts),
                 splitlines=True,
             )[-1]  # last line is the container ID
         except PyinfraError as e:
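A brief sketch of how the new Docker connector data might be supplied from an inventory; the @docker connector name and the (name, data) tuple form are standard pyinfra inventory usage, while the image and platform values here are only illustrative:

# inventory.py (hypothetical)
hosts = [
    (
        "@docker/ubuntu:24.04",
        {"docker_platform": "linux/amd64"},  # forwarded to `docker run --platform`
    ),
]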
pyinfra/connectors/ssh.py CHANGED
@@ -9,6 +9,7 @@ from typing import IO, TYPE_CHECKING, Any, Iterable, Optional, Protocol, Tuple

 import click
 from paramiko import AuthenticationException, BadHostKeyException, SFTPClient, SSHException
+from paramiko.agent import Agent
 from typing_extensions import TypedDict, Unpack, override

 from pyinfra import logger
@@ -286,10 +287,64 @@ class SSHConnector(BaseConnector):
                 f"Host key for {e.hostname} does not match.",
             )

+        except SSHException as e:
+            if self._retry_paramiko_agent_keys(hostname, kwargs, e):
+                return
+            raise
+
     @override
     def disconnect(self) -> None:
         self.get_file_transfer_connection.cache.clear()

+    def _retry_paramiko_agent_keys(
+        self,
+        hostname: str,
+        kwargs: dict[str, Any],
+        error: SSHException,
+    ) -> bool:
+        # Workaround for Paramiko multi-key bug (paramiko/paramiko#1390).
+        if "no existing session" not in str(error).lower():
+            return False
+
+        if not kwargs.get("allow_agent"):
+            return False
+
+        try:
+            agent_keys = list(Agent().get_keys())
+        except Exception:
+            return False
+
+        if not agent_keys:
+            return False
+
+        # Skip the first agent key, since Paramiko already attempted it
+        attempt_keys = agent_keys[1:] if len(agent_keys) > 1 else agent_keys
+
+        for agent_key in attempt_keys:
+            if self.client is not None:
+                try:
+                    self.client.close()
+                except Exception:
+                    pass
+
+            self.client = SSHClient()
+
+            single_key_kwargs = dict(kwargs)
+            single_key_kwargs["allow_agent"] = False
+            single_key_kwargs["pkey"] = agent_key
+
+            try:
+                self.client.connect(hostname, **single_key_kwargs)
+                return True
+            except AuthenticationException:
+                continue
+            except SSHException as retry_error:
+                if "no existing session" in str(retry_error).lower():
+                    continue
+                raise retry_error

+        return False
+
     @override
     def run_shell_command(
         self,
@@ -342,8 +397,10 @@ class SSHConnector(BaseConnector):
             get_pty=_get_pty,
         )

+        # Write any stdin and then close it
         if _stdin:
             write_stdin(_stdin, stdin_buffer)
+        stdin_buffer.close()

         combined_output = read_output_buffers(
             stdout_buffer,