pyinfra-3.0b0-py2.py3-none-any.whl → pyinfra-3.0b1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. pyinfra/api/__init__.py +3 -0
  2. pyinfra/api/arguments.py +5 -4
  3. pyinfra/api/arguments_typed.py +12 -2
  4. pyinfra/api/exceptions.py +19 -0
  5. pyinfra/api/facts.py +1 -1
  6. pyinfra/api/host.py +46 -7
  7. pyinfra/api/operation.py +77 -39
  8. pyinfra/api/operations.py +10 -11
  9. pyinfra/api/state.py +11 -2
  10. pyinfra/connectors/base.py +1 -1
  11. pyinfra/connectors/chroot.py +5 -6
  12. pyinfra/connectors/docker.py +11 -10
  13. pyinfra/connectors/dockerssh.py +5 -4
  14. pyinfra/connectors/local.py +5 -5
  15. pyinfra/connectors/ssh.py +44 -23
  16. pyinfra/connectors/terraform.py +9 -6
  17. pyinfra/connectors/util.py +1 -1
  18. pyinfra/connectors/vagrant.py +6 -5
  19. pyinfra/facts/choco.py +1 -1
  20. pyinfra/facts/deb.py +2 -2
  21. pyinfra/facts/postgres.py +168 -0
  22. pyinfra/facts/postgresql.py +5 -164
  23. pyinfra/facts/systemd.py +26 -10
  24. pyinfra/operations/files.py +5 -3
  25. pyinfra/operations/iptables.py +6 -0
  26. pyinfra/operations/pip.py +5 -0
  27. pyinfra/operations/postgres.py +347 -0
  28. pyinfra/operations/postgresql.py +17 -336
  29. pyinfra/operations/systemd.py +5 -3
  30. {pyinfra-3.0b0.dist-info → pyinfra-3.0b1.dist-info}/METADATA +6 -6
  31. {pyinfra-3.0b0.dist-info → pyinfra-3.0b1.dist-info}/RECORD +44 -43
  32. pyinfra_cli/commands.py +3 -2
  33. pyinfra_cli/exceptions.py +5 -0
  34. pyinfra_cli/main.py +2 -0
  35. pyinfra_cli/prints.py +22 -104
  36. tests/test_api/test_api_deploys.py +5 -5
  37. tests/test_api/test_api_operations.py +4 -4
  38. tests/test_connectors/test_ssh.py +52 -0
  39. tests/test_connectors/test_terraform.py +11 -8
  40. tests/test_connectors/test_vagrant.py +3 -3
  41. pyinfra_cli/inventory_dsl.py +0 -23
  42. {pyinfra-3.0b0.dist-info → pyinfra-3.0b1.dist-info}/LICENSE.md +0 -0
  43. {pyinfra-3.0b0.dist-info → pyinfra-3.0b1.dist-info}/WHEEL +0 -0
  44. {pyinfra-3.0b0.dist-info → pyinfra-3.0b1.dist-info}/entry_points.txt +0 -0
  45. {pyinfra-3.0b0.dist-info → pyinfra-3.0b1.dist-info}/top_level.txt +0 -0
pyinfra/connectors/local.py CHANGED
@@ -25,7 +25,8 @@ if TYPE_CHECKING:
25
25
 
26
26
  class LocalConnector(BaseConnector):
27
27
  """
28
- The ``@local`` connector executes changes on the local machine using subprocesses. **This connector is only compatible with MacOS & Linux hosts**.
28
+ The ``@local`` connector executes changes on the local machine using
29
+ subprocesses. **This connector is only compatible with MacOS & Linux hosts**.
29
30
 
30
31
  Examples:
31
32
 
@@ -38,8 +39,8 @@ class LocalConnector(BaseConnector):
38
39
  handles_execution = True
39
40
 
40
41
  @staticmethod
41
- def make_names_data(_=None):
42
- if _ is not None:
42
+ def make_names_data(name=None):
43
+ if name is not None:
43
44
  raise InventoryError("Cannot have more than one @local")
44
45
 
45
46
  yield "@local", {}, ["@local"]
@@ -205,8 +206,7 @@ class LocalConnector(BaseConnector):
205
206
 
206
207
  return True
207
208
 
208
- @staticmethod
209
- def check_can_rsync(host):
209
+ def check_can_rsync(self):
210
210
  if not find_executable("rsync"):
211
211
  raise NotImplementedError("The `rsync` binary is not available on this system.")
212
212
 
pyinfra/connectors/ssh.py CHANGED
@@ -2,7 +2,9 @@ from __future__ import annotations
2
2
 
3
3
  import shlex
4
4
  from distutils.spawn import find_executable
5
+ from random import uniform
5
6
  from socket import error as socket_error, gaierror
7
+ from time import sleep
6
8
  from typing import TYPE_CHECKING, Any, Iterable, Optional, Tuple
7
9
 
8
10
  import click
@@ -48,6 +50,10 @@ class ConnectorData(TypedDict):
48
50
 
49
51
  ssh_paramiko_connect_kwargs: dict
50
52
 
53
+ ssh_connect_retries: int
54
+ ssh_connect_retry_min_delay: float
55
+ ssh_connect_retry_max_delay: float
56
+
51
57
 
52
58
  connector_data_meta: dict[str, DataMeta] = {
53
59
  "ssh_hostname": DataMeta("SSH hostname"),
@@ -77,6 +83,15 @@ connector_data_meta: dict[str, DataMeta] = {
77
83
  "ssh_paramiko_connect_kwargs": DataMeta(
78
84
  "Override keyword arguments passed into Paramiko's ``SSHClient.connect``"
79
85
  ),
86
+ "ssh_connect_retries": DataMeta("Number of tries to connect via ssh", 0),
87
+ "ssh_connect_retry_min_delay": DataMeta(
88
+ "Lower bound for random delay between retries",
89
+ 0.1,
90
+ ),
91
+ "ssh_connect_retry_max_delay": DataMeta(
92
+ "Upper bound for random delay between retries",
93
+ 0.5,
94
+ ),
80
95
  }
81
96
 
82
97
 
@@ -125,8 +140,9 @@ class SSHConnector(BaseConnector):
125
140
 
126
141
  client: Optional[SSHClient] = None
127
142
 
128
- def make_names_data(hostname):
129
- yield "@ssh/{0}".format(hostname), {"ssh_hostname": hostname}, []
143
+ @staticmethod
144
+ def make_names_data(name):
145
+ yield "@ssh/{0}".format(name), {"ssh_hostname": name}, []
130
146
 
131
147
  def make_paramiko_kwargs(self) -> dict[str, Any]:
132
148
  kwargs = {
@@ -172,6 +188,29 @@ class SSHConnector(BaseConnector):
172
188
  return kwargs
173
189
 
174
190
  def connect(self) -> None:
191
+ retries = self.data["ssh_connect_retries"]
192
+
193
+ try:
194
+ while True:
195
+ try:
196
+ return self._connect()
197
+ except (SSHException, gaierror, socket_error, EOFError):
198
+ if retries == 0:
199
+ raise
200
+ retries -= 1
201
+ min_delay = self.data["ssh_connect_retry_min_delay"]
202
+ max_delay = self.data["ssh_connect_retry_max_delay"]
203
+ sleep(uniform(min_delay, max_delay))
204
+ except SSHException as e:
205
+ raise_connect_error(self.host, "SSH error", e)
206
+ except gaierror as e:
207
+ raise_connect_error(self.host, "Could not resolve hostname", e)
208
+ except socket_error as e:
209
+ raise_connect_error(self.host, "Could not connect", e)
210
+ except EOFError as e:
211
+ raise_connect_error(self.host, "EOF error", e)
212
+
213
+ def _connect(self) -> None:
175
214
  """
176
215
  Connect to a single host. Returns the SSH client if successful. Stateless by
177
216
  design so can be run in parallel.
@@ -221,18 +260,6 @@ class SSHConnector(BaseConnector):
221
260
  f"Host key for {e.hostname} does not match.",
222
261
  )
223
262
 
224
- except SSHException as e:
225
- raise_connect_error(self.host, "SSH error", e)
226
-
227
- except gaierror:
228
- raise_connect_error(self.host, "Could not resolve hostname", hostname)
229
-
230
- except socket_error as e:
231
- raise_connect_error(self.host, "Could not connect", e)
232
-
233
- except EOFError as e:
234
- raise_connect_error(self.host, "EOF error", e)
235
-
236
263
  def run_shell_command(
237
264
  self,
238
265
  command: StringCommand,
@@ -450,16 +477,10 @@ class SSHConnector(BaseConnector):
450
477
  self._put_file(filename_or_io, temp_file)
451
478
 
452
479
  # Make sure our sudo/su user can access the file
453
- if _su_user:
454
- command = StringCommand("setfacl", "-m", "u:{0}:r".format(_su_user), temp_file)
455
- elif _sudo_user:
456
- command = StringCommand("setfacl", "-m", "u:{0}:r".format(_sudo_user), temp_file)
457
- elif _doas_user:
458
- command = StringCommand("setfacl", "-m", "u:{0}:r".format(_doas_user), temp_file)
459
-
460
- if _su_user or _sudo_user or _doas_user:
480
+ other_user = _su_user or _sudo_user or _doas_user
481
+ if other_user:
461
482
  status, output = self.run_shell_command(
462
- command,
483
+ StringCommand("setfacl", "-m", f"u:{other_user}:r", temp_file),
463
484
  print_output=print_output,
464
485
  print_input=print_input,
465
486
  **arguments,
pyinfra/connectors/terraform.py CHANGED
@@ -28,7 +28,8 @@ def _flatten_dict(d: dict, parent_key: str = "", sep: str = "."):
28
28
 
29
29
  class TerraformInventoryConnector(BaseConnector):
30
30
  """
31
- Generate one or more SSH hosts from a Terraform output variable. The variable must be a list of hostnames or dictionaries.
31
+ Generate one or more SSH hosts from a Terraform output variable. The variable
32
+ must be a list of hostnames or dictionaries.
32
33
 
33
34
  Output is fetched from a flattened JSON dictionary output from ``terraform output
34
35
  -json``. For example the following object:
@@ -77,21 +78,23 @@ class TerraformInventoryConnector(BaseConnector):
77
78
  """
78
79
 
79
80
  @staticmethod
80
- def make_names_data(output_key=None):
81
+ def make_names_data(name=None):
81
82
  show_warning()
82
83
 
83
- if not output_key:
84
- raise InventoryError("No Terraform output key!")
84
+ if not name:
85
+ name = ""
85
86
 
86
87
  with progress_spinner({"fetch terraform output"}):
87
88
  tf_output_raw = local.shell("terraform output -json")
88
89
 
90
+ assert isinstance(tf_output_raw, str)
89
91
  tf_output = json.loads(tf_output_raw)
90
92
  tf_output = _flatten_dict(tf_output)
91
93
 
92
- tf_output_value = tf_output.get(output_key)
94
+ tf_output_value = tf_output.get(name)
93
95
  if tf_output_value is None:
94
- raise InventoryError(f"No Terraform output with key: `{output_key}`")
96
+ keys = "\n".join(f" - {k}" for k in tf_output.keys())
97
+ raise InventoryError(f"No Terraform output with key: `{name}`, valid keys:\n{keys}")
95
98
 
96
99
  if not isinstance(tf_output_value, list):
97
100
  raise InventoryError(
pyinfra/connectors/util.py CHANGED
@@ -40,7 +40,7 @@ def run_local_process(
40
40
  stdin=None,
41
41
  timeout: Optional[int] = None,
42
42
  print_output: bool = False,
43
- print_prefix=None,
43
+ print_prefix: str = "",
44
44
  ) -> tuple[int, "CommandOutput"]:
45
45
  process = Popen(command, shell=True, stdout=PIPE, stderr=PIPE, stdin=PIPE)
46
46
 
pyinfra/connectors/vagrant.py CHANGED
@@ -132,8 +132,8 @@ class VagrantInventoryConnector(BaseConnector):
132
132
  """
133
133
 
134
134
  @staticmethod
135
- def make_names_data(limit=None):
136
- vagrant_ssh_info = get_vagrant_config(limit)
135
+ def make_names_data(name=None):
136
+ vagrant_ssh_info = get_vagrant_config(name)
137
137
 
138
138
  logger.debug("Got Vagrant SSH info: \n%s", vagrant_ssh_info)
139
139
 
@@ -170,10 +170,11 @@ class VagrantInventoryConnector(BaseConnector):
170
170
  hosts.append(_make_name_data(current_host))
171
171
 
172
172
  if not hosts:
173
- if limit:
173
+ if name:
174
174
  raise InventoryError(
175
- "No running Vagrant instances matching `{0}` found!".format(limit)
175
+ "No running Vagrant instances matching `{0}` found!".format(name)
176
176
  )
177
177
  raise InventoryError("No running Vagrant instances found!")
178
178
 
179
- return hosts
179
+ for host in hosts:
180
+ yield host
pyinfra/facts/choco.py CHANGED
@@ -16,7 +16,7 @@ class ChocoPackages(FactBase):
16
16
  }
17
17
  """
18
18
 
19
- command = "choco list --local-only"
19
+ command = "choco list"
20
20
  shell_executable = "ps"
21
21
 
22
22
  default = dict
pyinfra/facts/deb.py CHANGED
@@ -48,8 +48,8 @@ class DebPackage(FactBase):
48
48
  """
49
49
 
50
50
  _regexes = {
51
- "name": r"^Package: ({0})$".format(DEB_PACKAGE_NAME_REGEX),
52
- "version": r"^Version: ({0})$".format(DEB_PACKAGE_VERSION_REGEX),
51
+ "name": r"^Package:\s+({0})$".format(DEB_PACKAGE_NAME_REGEX),
52
+ "version": r"^Version:\s+({0})$".format(DEB_PACKAGE_VERSION_REGEX),
53
53
  }
54
54
 
55
55
  requires_command = "dpkg"
pyinfra/facts/postgres.py ADDED
@@ -0,0 +1,168 @@
1
+ from __future__ import annotations
2
+
3
+ from pyinfra.api import FactBase, MaskString, QuoteString, StringCommand
4
+ from pyinfra.api.util import try_int
5
+
6
+ from .util.databases import parse_columns_and_rows
7
+
8
+
9
+ def make_psql_command(
10
+ database=None,
11
+ user=None,
12
+ password=None,
13
+ host=None,
14
+ port=None,
15
+ executable="psql",
16
+ ):
17
+ target_bits: list[str] = []
18
+
19
+ if password:
20
+ target_bits.append(MaskString('PGPASSWORD="{0}"'.format(password)))
21
+
22
+ target_bits.append(executable)
23
+
24
+ if database:
25
+ target_bits.append("-d {0}".format(database))
26
+
27
+ if user:
28
+ target_bits.append("-U {0}".format(user))
29
+
30
+ if host:
31
+ target_bits.append("-h {0}".format(host))
32
+
33
+ if port:
34
+ target_bits.append("-p {0}".format(port))
35
+
36
+ return StringCommand(*target_bits)
37
+
38
+
39
+ def make_execute_psql_command(command, **psql_kwargs):
40
+ return StringCommand(
41
+ make_psql_command(**psql_kwargs),
42
+ "-Ac",
43
+ QuoteString(command), # quote this whole item as a single shell argument
44
+ )
45
+
46
+
47
+ class PostgresFactBase(FactBase):
48
+ abstract = True
49
+
50
+ psql_command: str
51
+ requires_command = "psql"
52
+
53
+ def command(
54
+ self,
55
+ psql_user=None,
56
+ psql_password=None,
57
+ psql_host=None,
58
+ psql_port=None,
59
+ ):
60
+ return make_execute_psql_command(
61
+ self.psql_command,
62
+ user=psql_user,
63
+ password=psql_password,
64
+ host=psql_host,
65
+ port=psql_port,
66
+ )
67
+
68
+
69
+ class PostgresRoles(PostgresFactBase):
70
+ """
71
+ Returns a dict of PostgreSQL roles and data:
72
+
73
+ .. code:: python
74
+
75
+ {
76
+ "pyinfra": {
77
+ "super": true,
78
+ "createrole": false,
79
+ "createdb": false,
80
+ ...
81
+ },
82
+ }
83
+ """
84
+
85
+ default = dict
86
+ psql_command = "SELECT * FROM pg_catalog.pg_roles"
87
+
88
+ def process(self, output):
89
+ # Remove the last line of the output (row count)
90
+ output = output[:-1]
91
+ rows = parse_columns_and_rows(
92
+ output,
93
+ "|",
94
+ # Remove the "rol" prefix on column names
95
+ remove_column_prefix="rol",
96
+ )
97
+
98
+ users = {}
99
+
100
+ for details in rows:
101
+ for key, value in list(details.items()):
102
+ if key in ("oid", "connlimit"):
103
+ details[key] = try_int(value)
104
+
105
+ if key in (
106
+ "super",
107
+ "inherit",
108
+ "createrole",
109
+ "createdb",
110
+ "canlogin",
111
+ "replication",
112
+ "bypassrls",
113
+ ):
114
+ details[key] = value == "t"
115
+
116
+ users[details.pop("name")] = details
117
+
118
+ return users
119
+
120
+
121
+ class PostgresDatabases(PostgresFactBase):
122
+ """
123
+ Returns a dict of PostgreSQL databases and metadata:
124
+
125
+ .. code:: python
126
+
127
+ {
128
+ "pyinfra_stuff": {
129
+ "encoding": "UTF8",
130
+ "collate": "en_US.UTF-8",
131
+ "ctype": "en_US.UTF-8",
132
+ ...
133
+ },
134
+ }
135
+ """
136
+
137
+ default = dict
138
+ psql_command = "SELECT pg_catalog.pg_encoding_to_char(encoding), * FROM pg_catalog.pg_database"
139
+
140
+ def process(self, output):
141
+ # Remove the last line of the output (row count)
142
+ output = output[:-1]
143
+ rows = parse_columns_and_rows(
144
+ output,
145
+ "|",
146
+ # Remove the "dat" prefix on column names
147
+ remove_column_prefix="dat",
148
+ )
149
+
150
+ databases = {}
151
+
152
+ for details in rows:
153
+ details["encoding"] = details.pop("pg_encoding_to_char")
154
+
155
+ for key, value in list(details.items()):
156
+ if key.endswith("id") or key in (
157
+ "dba",
158
+ "tablespace",
159
+ "connlimit",
160
+ ):
161
+ details[key] = try_int(value)
162
+
163
+ if key in ("istemplate", "allowconn"):
164
+ details[key] = value == "t"
165
+
166
+ databases[details.pop("name")] = details
167
+
168
+ return databases
pyinfra/facts/postgresql.py CHANGED
@@ -1,168 +1,9 @@
1
- from __future__ import annotations
1
+ from .postgres import PostgresDatabases, PostgresRoles
2
2
 
3
- from pyinfra.api import FactBase, MaskString, QuoteString, StringCommand
4
- from pyinfra.api.util import try_int
5
3
 
6
- from .util.databases import parse_columns_and_rows
4
+ class PostgresqlRoles(PostgresRoles):
5
+ deprecated = True
7
6
 
8
7
 
9
- def make_psql_command(
10
- database=None,
11
- user=None,
12
- password=None,
13
- host=None,
14
- port=None,
15
- executable="psql",
16
- ):
17
- target_bits: list[str] = []
18
-
19
- if password:
20
- target_bits.append(MaskString('PGPASSWORD="{0}"'.format(password)))
21
-
22
- target_bits.append(executable)
23
-
24
- if database:
25
- target_bits.append("-d {0}".format(database))
26
-
27
- if user:
28
- target_bits.append("-U {0}".format(user))
29
-
30
- if host:
31
- target_bits.append("-h {0}".format(host))
32
-
33
- if port:
34
- target_bits.append("-p {0}".format(port))
35
-
36
- return StringCommand(*target_bits)
37
-
38
-
39
- def make_execute_psql_command(command, **psql_kwargs):
40
- return StringCommand(
41
- make_psql_command(**psql_kwargs),
42
- "-Ac",
43
- QuoteString(command), # quote this whole item as a single shell argument
44
- )
45
-
46
-
47
- class PostgresqlFactBase(FactBase):
48
- abstract = True
49
-
50
- psql_command: str
51
- requires_command = "psql"
52
-
53
- def command(
54
- self,
55
- psql_user=None,
56
- psql_password=None,
57
- psql_host=None,
58
- psql_port=None,
59
- ):
60
- return make_execute_psql_command(
61
- self.psql_command,
62
- user=psql_user,
63
- password=psql_password,
64
- host=psql_host,
65
- port=psql_port,
66
- )
67
-
68
-
69
- class PostgresqlRoles(PostgresqlFactBase):
70
- """
71
- Returns a dict of PostgreSQL roles and data:
72
-
73
- .. code:: python
74
-
75
- {
76
- "pyinfra": {
77
- "super": true,
78
- "createrole": false,
79
- "createdb": false,
80
- ...
81
- },
82
- }
83
- """
84
-
85
- default = dict
86
- psql_command = "SELECT * FROM pg_catalog.pg_roles"
87
-
88
- def process(self, output):
89
- # Remove the last line of the output (row count)
90
- output = output[:-1]
91
- rows = parse_columns_and_rows(
92
- output,
93
- "|",
94
- # Remove the "rol" prefix on column names
95
- remove_column_prefix="rol",
96
- )
97
-
98
- users = {}
99
-
100
- for details in rows:
101
- for key, value in list(details.items()):
102
- if key in ("oid", "connlimit"):
103
- details[key] = try_int(value)
104
-
105
- if key in (
106
- "super",
107
- "inherit",
108
- "createrole",
109
- "createdb",
110
- "canlogin",
111
- "replication",
112
- "bypassrls",
113
- ):
114
- details[key] = value == "t"
115
-
116
- users[details.pop("name")] = details
117
-
118
- return users
119
-
120
-
121
- class PostgresqlDatabases(PostgresqlFactBase):
122
- """
123
- Returns a dict of PostgreSQL databases and metadata:
124
-
125
- .. code:: python
126
-
127
- {
128
- "pyinfra_stuff": {
129
- "encoding": "UTF8",
130
- "collate": "en_US.UTF-8",
131
- "ctype": "en_US.UTF-8",
132
- ...
133
- },
134
- }
135
- """
136
-
137
- default = dict
138
- psql_command = "SELECT pg_catalog.pg_encoding_to_char(encoding), * FROM pg_catalog.pg_database"
139
-
140
- def process(self, output):
141
- # Remove the last line of the output (row count)
142
- output = output[:-1]
143
- rows = parse_columns_and_rows(
144
- output,
145
- "|",
146
- # Remove the "dat" prefix on column names
147
- remove_column_prefix="dat",
148
- )
149
-
150
- databases = {}
151
-
152
- for details in rows:
153
- details["encoding"] = details.pop("pg_encoding_to_char")
154
-
155
- for key, value in list(details.items()):
156
- if key.endswith("id") or key in (
157
- "dba",
158
- "tablespace",
159
- "connlimit",
160
- ):
161
- details[key] = try_int(value)
162
-
163
- if key in ("istemplate", "allowconn"):
164
- details[key] = value == "t"
165
-
166
- databases[details.pop("name")] = details
167
-
168
- return databases
8
+ class PostgresqlDatabases(PostgresDatabases):
9
+ deprecated = True
pyinfra/facts/systemd.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import re
2
- from typing import Dict
2
+ from typing import Dict, Iterable
3
3
 
4
- from pyinfra.api import FactBase
4
+ from pyinfra.api import FactBase, FactTypeError, QuoteString, StringCommand
5
5
 
6
6
  # Valid unit names consist of a "name prefix" and a dot and a suffix specifying the unit type.
7
7
  # The "unit prefix" must consist of one or more valid characters
@@ -22,18 +22,16 @@ SYSTEMD_UNIT_NAME_REGEX = (
22
22
 
23
23
  def _make_systemctl_cmd(user_mode=False, machine=None, user_name=None):
24
24
  # base command for normal and user mode
25
- systemctl_cmd = "systemctl --user" if user_mode else "systemctl"
25
+ systemctl_cmd = ["systemctl --user"] if user_mode else ["systemctl"]
26
26
 
27
27
  # add user and machine flag if given in args
28
28
  if machine is not None:
29
29
  if user_name is not None:
30
- machine_opt = "--machine={1}@{0}".format(machine, user_name)
30
+ systemctl_cmd.append("--machine={1}@{0}".format(machine, user_name))
31
31
  else:
32
- machine_opt = "--machine={0}".format(machine)
32
+ systemctl_cmd.append("--machine={0}".format(machine))
33
33
 
34
- systemctl_cmd = "{0} {1}".format(systemctl_cmd, machine_opt)
35
-
36
- return systemctl_cmd
34
+ return StringCommand(*systemctl_cmd)
37
35
 
38
36
 
39
37
  class SystemdStatus(FactBase[Dict[str, bool]]):
@@ -59,14 +57,32 @@ class SystemdStatus(FactBase[Dict[str, bool]]):
59
57
  state_key = "SubState"
60
58
  state_values = ["running", "waiting", "exited"]
61
59
 
62
- def command(self, user_mode=False, machine=None, user_name=None):
60
+ def command(self, user_mode=False, machine=None, user_name=None, services=None):
63
61
  fact_cmd = _make_systemctl_cmd(
64
62
  user_mode=user_mode,
65
63
  machine=machine,
66
64
  user_name=user_name,
67
65
  )
68
66
 
69
- return f"{fact_cmd} show --all --property Id --property {self.state_key} '*'"
67
+ if services is None:
68
+ service_strs = [QuoteString("*")]
69
+ elif isinstance(services, str):
70
+ service_strs = [QuoteString(services)]
71
+ elif isinstance(services, Iterable):
72
+ service_strs = [QuoteString(s) for s in services]
73
+ else:
74
+ raise FactTypeError(f"Invalid type passed for services argument: {type(services)}")
75
+
76
+ return StringCommand(
77
+ fact_cmd,
78
+ "show",
79
+ "--all",
80
+ "--property",
81
+ "Id",
82
+ "--property",
83
+ self.state_key,
84
+ *service_strs,
85
+ )
70
86
 
71
87
  def process(self, output) -> Dict[str, bool]:
72
88
  services: Dict[str, bool] = {}
pyinfra/operations/files.py CHANGED
@@ -11,6 +11,7 @@ import traceback
11
11
  from datetime import timedelta
12
12
  from fnmatch import fnmatch
13
13
  from io import StringIO
14
+ from pathlib import Path
14
15
  from typing import Union
15
16
 
16
17
  from jinja2 import TemplateRuntimeError, TemplateSyntaxError, UndefinedError
@@ -30,6 +31,7 @@ from pyinfra.api import (
30
31
  from pyinfra.api.command import make_formatted_string_command
31
32
  from pyinfra.api.util import (
32
33
  get_call_location,
34
+ get_file_io,
33
35
  get_file_sha1,
34
36
  get_path_permissions_mode,
35
37
  get_template,
@@ -569,7 +571,7 @@ def sync(
569
571
  put_files = []
570
572
  ensure_dirnames = []
571
573
  for dirpath, dirnames, filenames in os.walk(src, topdown=True):
572
- remote_dirpath = os.path.normpath(os.path.relpath(dirpath, src))
574
+ remote_dirpath = Path(os.path.normpath(os.path.relpath(dirpath, src))).as_posix()
573
575
 
574
576
  # Filter excluded dirs
575
577
  for child_dir in dirnames[:]:
@@ -999,7 +1001,7 @@ def template(src, dest, user=None, group=None, mode=None, create_remote_dir=True
999
1001
  line_number = trace_frames[-1][1]
1000
1002
 
1001
1003
  # Quickly read the line in question and one above/below for nicer debugging
1002
- with open(src, "r") as f:
1004
+ with get_file_io(src, "r") as f:
1003
1005
  template_lines = f.readlines()
1004
1006
 
1005
1007
  template_lines = [line.strip() for line in template_lines]
@@ -1012,7 +1014,7 @@ def template(src, dest, user=None, group=None, mode=None, create_remote_dir=True
1012
1014
  e,
1013
1015
  "\n".join(relevant_lines),
1014
1016
  ),
1015
- )
1017
+ ) from None
1016
1018
 
1017
1019
  output_file = StringIO(output)
1018
1020
  # Set the template attribute for nicer debugging