atlas-init: atlas_init-0.3.7-py3-none-any.whl → atlas_init-0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
atlas_init/__init__.py CHANGED
@@ -1,6 +1,6 @@
  from pathlib import Path
 
- VERSION = "0.3.7"
+ VERSION = "0.4.0"
 
 
  def running_in_repo() -> bool:
@@ -91,3 +91,12 @@ test_suites:
    vars:
      cluster_info_m10: true
      stream_instance: true
+ - name: trigger
+   repo_go_packages:
+     cfn:
+       - cfn-resources/trigger
+   vars:
+     cluster_info_m10: true
+   post_apply_hooks:
+     - name: create_realm_resources
+       locate: atlas_init.cli_root.trigger.create_realm_app
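The new `post_apply_hooks` entries pair a human-readable `name` with a `locate` dotted path that must resolve to an importable, zero-argument callable (the reworked `apply` command below calls the resolved function with no arguments). A minimal sketch of how such an entry maps onto a model, using hypothetical `PostApplyHook`/`SuiteWithHooks` classes; the real models live in `atlas_init.settings.config` and may differ.

```python
from pydantic import BaseModel


class PostApplyHook(BaseModel):
    """Hypothetical shape of one post_apply_hooks entry."""

    name: str
    locate: str  # dotted path, e.g. "atlas_init.cli_root.trigger.create_realm_app"


class SuiteWithHooks(BaseModel):
    """Illustrative subset of a test-suite entry carrying the new field."""

    name: str
    post_apply_hooks: list[PostApplyHook] = []


# Mirrors the YAML added above for the "trigger" suite.
suite = SuiteWithHooks(
    name="trigger",
    post_apply_hooks=[
        PostApplyHook(name="create_realm_resources", locate="atlas_init.cli_root.trigger.create_realm_app")
    ],
)
print(suite)
```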
atlas_init/cli.py CHANGED
@@ -40,7 +40,7 @@ from atlas_init.repos.path import (
      find_paths,
      resource_name,
  )
- from atlas_init.settings.config import RepoAliasNotFoundError
+ from atlas_init.settings.config import RepoAliasNotFoundError, TestSuite
  from atlas_init.settings.env_vars import (
      active_suites,
      init_settings,
@@ -70,10 +70,15 @@ def plan(context: typer.Context, *, skip_outputs: bool = False):
 
  @app_command()
  def apply(context: typer.Context, *, skip_outputs: bool = False):
-     _plan_or_apply(context.args, "apply", skip_outputs=skip_outputs)
+     suites = _plan_or_apply(context.args, "apply", skip_outputs=skip_outputs)
+     for suite in suites:
+         for hook in suite.post_apply_hooks:
+             logger.info(f"running post apply hook: {hook.name}")
+             hook_func = locate(hook.locate)
+             hook_func()  # type: ignore
 
 
- def _plan_or_apply(extra_args: list[str], command: Literal["plan", "apply"], *, skip_outputs: bool):
+ def _plan_or_apply(extra_args: list[str], command: Literal["plan", "apply"], *, skip_outputs: bool) -> list[TestSuite]:
      settings = init_settings()
      logger.info(f"using the '{command}' command, extra args: {extra_args}")
      try:
@@ -97,6 +102,7 @@ def _plan_or_apply(extra_args: list[str], command: Literal["plan", "apply"], *,
      if settings.env_vars_generated.exists():
          dump_vscode_dotenv(settings.env_vars_generated, settings.env_vars_vs_code)
          logger.info(f"your .env file is ready @ {settings.env_vars_vs_code}")
+     return suites
 
 
  @app_command()
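To illustrate the new control flow: `apply` now keeps the suites returned by `_plan_or_apply` and runs every suite's post-apply hooks by resolving the dotted path and invoking the result with no arguments. Below is a standalone sketch of that loop, assuming `locate` behaves like `pydoc.locate`; `DemoSuite`/`DemoHook` are illustrative stand-ins for the real `TestSuite` model.

```python
import logging
from dataclasses import dataclass, field
from pydoc import locate

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class DemoHook:  # stand-in for a post_apply_hooks entry
    name: str
    locate: str


@dataclass
class DemoSuite:  # stand-in for atlas_init.settings.config.TestSuite
    name: str
    post_apply_hooks: list[DemoHook] = field(default_factory=list)


def run_post_apply_hooks(suites: list[DemoSuite]) -> None:
    """Mirror of the loop added to `apply`: resolve each hook and call it with no args."""
    for suite in suites:
        for hook in suite.post_apply_hooks:
            logger.info(f"running post apply hook: {hook.name}")
            hook_func = locate(hook.locate)
            assert callable(hook_func), f"cannot resolve {hook.locate}"
            hook_func()


if __name__ == "__main__":
    # Uses a stdlib callable so the sketch runs without atlas-init installed.
    run_post_apply_hooks([DemoSuite("trigger", [DemoHook("demo", "platform.python_version")])])
```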
atlas_init/cli_cfn/app.py CHANGED
@@ -2,7 +2,6 @@ import logging
  import os
 
  import typer
- from model_lib import parse_payload
  from zero_3rdparty.file_utils import clean_dir
 
  from atlas_init.cli_cfn.aws import (
@@ -18,8 +17,13 @@ from atlas_init.cli_cfn.aws import (
  from atlas_init.cli_cfn.cfn_parameter_finder import (
      read_execution_role,
  )
+ from atlas_init.cli_cfn.contract import contract_test_cmd
  from atlas_init.cli_cfn.example import example_cmd
- from atlas_init.cli_cfn.files import create_sample_file, has_md_link, iterate_schemas
+ from atlas_init.cli_cfn.files import (
+     create_sample_file_from_input,
+     has_md_link,
+     iterate_schemas,
+ )
  from atlas_init.cli_helper.run import run_command_is_ok
  from atlas_init.cloud.aws import run_in_regions
  from atlas_init.repos.cfn import (
@@ -30,6 +34,7 @@ from atlas_init.settings.env_vars import active_suites, init_settings
 
  app = typer.Typer(no_args_is_help=True)
  app.command(name="example")(example_cmd)
+ app.command(name="contract-test")(contract_test_cmd)
  logger = logging.getLogger(__name__)
 
 
@@ -51,7 +56,7 @@ def reg(
      deregister_cfn_resource_type(type_name, deregister=not dry_run, region_filter=region)
      logger.info(f"ready to activate {type_name}")
      settings = init_settings()
-     cfn_execution_role = read_execution_role(settings.load_env_vars_generated())
+     cfn_execution_role = read_execution_role(settings.load_env_vars_full())
      last_third_party = get_last_cfn_type(type_name, region, is_third_party=True)
      assert last_third_party, f"no 3rd party extension found for {type_name} in {region}"
      if dry_run:
@@ -97,7 +102,7 @@ def inputs(
      cwd = current_dir()
      suite = suites[0]
      assert suite.cwd_is_repo_go_pkg(cwd, repo_alias="cfn")
-     env_extra = settings.load_env_vars_generated()
+     env_extra = settings.load_env_vars_full()
      CREATE_FILENAME = "cfn-test-create-inputs.sh"  # noqa: N806
      create_dirs = ["test/contract-testing", "test"]
      parent_dir = None
@@ -109,7 +114,7 @@ def inputs(
      assert parent_dir, f"unable to find a {CREATE_FILENAME} in {create_dirs} in {cwd}"
      if not run_command_is_ok(
          cwd=cwd,
-         cmd=[f"./{parent_dir}/{CREATE_FILENAME}", *context.args],
+         cmd=f"./{parent_dir}/{CREATE_FILENAME}" + " ".join(context.args),
          env={**os.environ} | env_extra,
          logger=logger,
      ):
@@ -131,20 +136,7 @@ def inputs(
          logger.info(f"input exist at inputs/{file.name} ✅")
          if skip_samples:
              continue
-         resource_state = parse_payload(file)
-         assert isinstance(resource_state, dict), f"input file with not a dict {resource_state}"
-         samples_file = samples_dir / file.name
-         if file.name.endswith("_create.json"):
-             create_sample_file(samples_file, log_group_name, resource_state)
-         if file.name.endswith("_update.json"):
-             prev_state_path = file.parent / file.name.replace("_update.json", "_create.json")
-             prev_state: dict = parse_payload(prev_state_path)  # type: ignore
-             create_sample_file(
-                 samples_file,
-                 log_group_name,
-                 resource_state,
-                 prev_resource_state=prev_state,
-             )
+         create_sample_file_from_input(samples_dir, log_group_name, file)
      if single_input:
          for file in sorted(inputs_dir.glob("*.json")):
              new_name = file.name.replace(expected_input, "inputs_1")
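Note the registration style above: the new command is attached to the existing Typer sub-app by calling `app.command(name="contract-test")(contract_test_cmd)` on an already-defined function rather than decorating it at its definition site. A tiny self-contained Typer sketch of that pattern, with placeholder command functions standing in for the real ones:

```python
import typer

app = typer.Typer(no_args_is_help=True)


def example_cmd(example_name: str = typer.Option("", "-e", "--example-name")):
    """Placeholder standing in for atlas_init.cli_cfn.example.example_cmd."""
    typer.echo(f"example: {example_name or '<default>'}")


def contract_test_cmd(only_names: list[str] = typer.Option(None, "-n", "--only-names")):
    """Placeholder standing in for atlas_init.cli_cfn.contract.contract_test_cmd."""
    typer.echo(f"contract tests: {only_names or 'all'}")


# Register plain functions as sub-commands without decorating them where they are defined.
app.command(name="example")(example_cmd)
app.command(name="contract-test")(contract_test_cmd)

if __name__ == "__main__":
    app()
```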
atlas_init/cli_cfn/aws.py CHANGED
@@ -436,7 +436,7 @@ def ensure_resource_type_activated(
              default=True,
          )
      ):
-         assert run_command_is_ok(cmd=submit_cmd.split(), env=None, cwd=resource_path, logger=logger)
+         assert run_command_is_ok(cmd=submit_cmd, env=None, cwd=resource_path, logger=logger)
      cfn_type_details = get_last_cfn_type(type_name, region, is_third_party=False)
      if cfn_type_details is None:
          third_party = get_last_cfn_type(type_name, region, is_third_party=True, force_version=force_version)
atlas_init/cli_cfn/contract.py ADDED
@@ -0,0 +1,227 @@
+ import logging
+ import os
+ import re
+ from pathlib import Path
+
+ import typer
+ from model_lib import Entity
+ from pydantic import Field
+ from zero_3rdparty.file_utils import ensure_parents_write_text
+
+ from atlas_init.cli_cfn.files import create_sample_file_from_input
+ from atlas_init.cli_helper.run import (
+     run_binary_command_is_ok,
+ )
+ from atlas_init.cli_helper.run_manager import RunManager
+ from atlas_init.cli_root import is_dry_run
+ from atlas_init.repos.path import Repo, ResourcePaths, find_paths
+ from atlas_init.settings.env_vars import AtlasInitSettings, init_settings
+
+ logger = logging.getLogger(__name__)
+
+
+ class RunContractTest(Entity):
+     resource_path: Path
+     repo_path: Path
+     cfn_region: str
+     aws_profile: str
+     skip_build: bool = False
+     dry_run: bool = Field(default_factory=is_dry_run)
+     only_names: list[str] | None = None
+
+     @property
+     def run_tests_command(self) -> tuple[str, str]:
+         if self.only_names:
+             names = " ".join(f"-k {name}" for name in self.only_names)
+             return (
+                 "cfn",
+                 f"test --function-name TestEntrypoint --verbose --region {self.cfn_region} -- {names}",
+             )
+         return (
+             "cfn",
+             f"test --function-name TestEntrypoint --verbose --region {self.cfn_region}",
+         )
+
+
+ class RunContractTestOutput(Entity):
+     sam_local_logs: str
+     sam_local_exit_code: int
+     contract_test_ok: bool
+     rpdk_log: str
+
+
+ class CreateContractTestInputs(Entity):
+     resource_path: Path
+     env_vars_generated: dict[str, str]
+     log_group_name: str
+
+
+ class CFNBuild(Entity):
+     resource_path: Path
+     dry_run: bool = Field(default_factory=is_dry_run)
+     is_debug: bool = False
+     tags: str = "logging callback metrics scheduler"
+     cgo: int = 0
+     goarch: str = "amd64"
+     goos: str = "linux"
+     git_sha: str = "local"
+     ldflags: str = "-s -w -X github.com/mongodb/mongodbatlas-cloudformation-resources/util.defaultLogLevel=info -X github.com/mongodb/mongodbatlas-cloudformation-resources/version.Version=${CFNREP_GIT_SHA}"
+
+     @property
+     def extra_env(self) -> dict[str, str]:
+         return {"GOOS": self.goos, "CGO_ENABLED": str(self.cgo), "GOARCH": self.goarch}
+
+     @property
+     def flags(self) -> str:
+         return self.ldflags.replace("${CFNREP_GIT_SHA}", self.git_sha)
+
+     @property
+     def command_build(self) -> str:
+         return f'build -ldflags="{self.flags}" -tags="{self.tags}" -o bin/bootstrap cmd/main.go'
+
+     @property
+     def cfn_generate(self) -> str:
+         return "generate"
+
+     @property
+     def commands(self) -> list[tuple[str, str]]:
+         return [
+             ("cfn", self.cfn_generate),
+             ("go", self.command_build),
+         ]
+
+
+ def contract_test_cmd(
+     only_names: list[str] = typer.Option(None, "-n", "--only-names", help="only run these contract tests"),
+ ):
+     result = contract_test(only_names=only_names)
+     if result.contract_test_ok:
+         logger.info("contract tests passed 🥳")
+     else:
+         logger.error("contract tests failed 💥")
+         logger.error(
+             f"function logs (exit_code={result.sam_local_exit_code}):\n {result.sam_local_logs}\n\nRPDK logs:\n{result.rpdk_log[-10_000:]}"
+         )
+         raise typer.Exit(1)
+     return result
+
+
+ def contract_test(
+     settings: AtlasInitSettings | None = None,
+     resource_paths: ResourcePaths | None = None,
+     only_names: list[str] | None = None,
+ ):
+     settings = settings or init_settings()
+     resource_paths = resource_paths or find_paths(Repo.CFN)
+     resource_name = resource_paths.resource_name
+     generated_env_vars = settings.load_env_vars_full()
+     create_inputs = CreateContractTestInputs(
+         resource_path=resource_paths.resource_path,
+         env_vars_generated=generated_env_vars,
+         log_group_name=f"mongodb-atlas-{resource_name}-logs",
+     )
+     create_response = create_contract_test_inputs(create_inputs)
+     create_response.log_input_files(logger)
+     run_contract_test = RunContractTest(
+         resource_path=resource_paths.resource_path,
+         repo_path=resource_paths.repo_path,
+         aws_profile=settings.AWS_PROFILE,
+         cfn_region=settings.cfn_region,
+         only_names=only_names,
+     )
+     if run_contract_test.skip_build:
+         logger.info("skipping build")
+     else:
+         build_event = CFNBuild(resource_path=resource_paths.resource_path)
+         build(build_event)
+         logger.info("build ok ✅")
+     return run_contract_tests(run_contract_test)
+
+
+ class CreateContractTestInputsResponse(Entity):
+     input_files: list[Path]
+     sample_files: list[Path]
+
+     def log_input_files(self, logger: logging.Logger):
+         inputs = self.input_files
+         if not inputs:
+             logger.warning("no input files created")
+             return
+         inputs_dir = self.input_files[0].parent
+         logger.info(f"{len(inputs)} inputs created in '{inputs_dir}'")
+         logger.info("\n".join(f"'{file.name}'" for file in self.input_files))
+
+
+ def create_contract_test_inputs(
+     event: CreateContractTestInputs,
+ ) -> CreateContractTestInputsResponse:
+     inputs_dir = event.resource_path / "inputs"
+     samples_dir = event.resource_path / "samples"
+     test_dir = event.resource_path / "test"
+     sample_files = []
+     input_files = []
+     for template in sorted(test_dir.glob("*.template.json")):
+         template_file = template.read_text()
+         template_file = file_replacements(template_file, event.env_vars_generated, template.name)
+         inputs_file = inputs_dir / template.name.replace(".template", "")
+         ensure_parents_write_text(inputs_file, template_file)
+         input_files.append(inputs_file)
+         sample_file = create_sample_file_from_input(samples_dir, event.log_group_name, inputs_file)
+         sample_files.append(sample_file)
+     return CreateContractTestInputsResponse(input_files=input_files, sample_files=sample_files)
+
+
+ def file_replacements(text: str, replacements: dict[str, str], file_name: str) -> str:
+     for match in re.finditer(r"\${(\w+)}", text):
+         var_name = match.group(1)
+         if var_name in replacements:
+             text = text.replace(match.group(0), replacements[var_name])
+         else:
+             logger.warning(f"found placeholder {match.group(0)} in {file_name} but no replacement")
+     return text
+
+
+ def build(event: CFNBuild):
+     for binary, command in event.commands:
+         is_ok = run_binary_command_is_ok(
+             binary,
+             command,
+             cwd=event.resource_path,
+             logger=logger,
+             dry_run=event.dry_run,
+             env={**os.environ, **event.extra_env},
+         )
+         if not is_ok:
+             logger.critical(f"failed to run {binary} {command}")
+             raise typer.Exit(1)
+
+
+ def run_contract_tests(event: RunContractTest) -> RunContractTestOutput:
+     with RunManager(dry_run=event.dry_run) as manager:
+         manager.set_timeouts(3)
+         resource_path = event.resource_path
+         run_future = manager.run_process_wait_on_log(
+             f"local start-lambda --skip-pull-image --region {event.cfn_region}",
+             binary="sam",
+             cwd=resource_path,
+             logger=logger,
+             line_in_log="Running on http://",
+             timeout=60,
+         )
+         binary, test_cmd = event.run_tests_command
+         test_result_ok = run_binary_command_is_ok(
+             binary,
+             test_cmd,
+             cwd=resource_path,
+             logger=logger,
+             dry_run=event.dry_run,
+         )
+         extra_log = resource_path / "rpdk.log"
+         log_content = extra_log.read_text() if extra_log.exists() else ""
+         sam_local_result = run_future.result(timeout=1)
+         return RunContractTestOutput(
+             sam_local_logs=sam_local_result.result_str,
+             sam_local_exit_code=sam_local_result.exit_code or -1,
+             contract_test_ok=test_result_ok,
+             rpdk_log=log_content,
+         )
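Putting the new module together: `contract_test()` renders `test/*.template.json` into `inputs/` (substituting `${VAR}` placeholders from the generated env vars), optionally builds the Go handler, starts `sam local start-lambda` in the background, and runs `cfn test` against it. A hedged usage sketch follows, assuming an initialized atlas-init profile and a working directory inside a `cfn-resources/<resource>` package; any test name you might pass via `only_names` is illustrative.

```python
import logging

from atlas_init.cli_cfn.contract import contract_test

logging.basicConfig(level=logging.INFO)

# Runs the full flow for the resource inferred from the current working directory
# (settings and resource paths are resolved internally when left as None).
result = contract_test(only_names=None)
print("contract tests ok:", result.contract_test_ok)
print("sam local exit code:", result.sam_local_exit_code)
```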
atlas_init/cli_cfn/example.py CHANGED
@@ -53,6 +53,9 @@ class CfnExampleInputs(CfnType):
          assert self.region_filter, "region is required"
          assert self.execution_role.startswith("arn:aws:iam::"), f"invalid execution role: {self.execution_role}"
          assert self.region
+         if self.delete_stack_first and self.operation == Operation.UPDATE:
+             err_msg = "cannot delete first when updating"
+             raise ValueError(err_msg)
          return self
 
      @property
@@ -76,7 +79,7 @@ def example_cmd(
      operation: str = typer.Argument(...),
      example_name: str = typer.Option("", "-e", "--example-name", help="example filestem"),
      resource_params: list[str] = typer.Option(
-         ..., "-r", "--resource-param", default_factory=list, help="key=value, can be set many times"
+         ...,
+         "-r",
+         "--resource-param",
+         default_factory=list,
+         help="key=value, can be set many times",
      ),
      stack_timeout_s: int = typer.Option(3600, "-t", "--stack-timeout-s"),
      delete_first: bool = typer.Option(False, "-d", "--delete-first", help="Delete existing stack first"),
@@ -93,9 +100,9 @@ def example_cmd(
      register_all_types_in_example: bool = typer.Option(False, "--reg-all", help="Check all types"),
  ):
      settings = init_settings()
-     assert settings.cfn_config, "no cfn config found, re-run atlas_init apply with CFN flags"
+     assert settings.tf_vars, "no cfn config found, re-run atlas_init apply with CFN flags"
      repo_path, resource_path, _ = find_paths(Repo.CFN)
-     env_vars_generated = settings.load_env_vars_generated()
+     env_vars_generated = settings.load_env_vars_full()
      inputs = CfnExampleInputs(
          type_name=type_name or infer_cfn_type_name(),
          example_name=example_name,
@@ -116,13 +123,18 @@ def example_cmd(
      example_handler(inputs, repo_path, resource_path, settings)
 
 
- def example_handler(inputs: CfnExampleInputs, repo_path: Path, resource_path: Path, settings: AtlasInitSettings):
+ def example_handler(
+     inputs: CfnExampleInputs,
+     repo_path: Path,
+     resource_path: Path,
+     settings: AtlasInitSettings,
+ ):
      logger.info(
          f"about to {inputs.operation} stack {inputs.stack_name} for {inputs.type_name} in {inputs.region_filter} params: {inputs.resource_params}"
      )
      type_name = inputs.type_name
      stack_name = inputs.stack_name
-     env_vars_generated = settings.load_env_vars_generated()
+     env_vars_generated = settings.load_env_vars_full()
      region = inputs.region
      operation = inputs.operation
      stack_timeout_s = inputs.stack_timeout_s
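The new guard rejects "delete first" for update operations during model validation rather than failing later at runtime. A minimal pydantic sketch of the same rule, with a simplified `Operation` enum and `ExampleInputs` model standing in for the real `Operation` and `CfnExampleInputs`:

```python
from enum import Enum

from pydantic import BaseModel, model_validator


class Operation(str, Enum):  # simplified stand-in for the real enum
    CREATE = "CREATE"
    UPDATE = "UPDATE"
    DELETE = "DELETE"


class ExampleInputs(BaseModel):  # illustrative subset of CfnExampleInputs
    operation: Operation
    delete_stack_first: bool = False

    @model_validator(mode="after")
    def _check_flags(self):
        # Deleting the stack first makes no sense when the goal is to update it.
        if self.delete_stack_first and self.operation == Operation.UPDATE:
            err_msg = "cannot delete first when updating"
            raise ValueError(err_msg)
        return self


ExampleInputs(operation=Operation.CREATE, delete_stack_first=True)  # ok
# ExampleInputs(operation=Operation.UPDATE, delete_stack_first=True)  # raises ValidationError
```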
atlas_init/cli_cfn/files.py CHANGED
@@ -5,19 +5,37 @@ from collections.abc import Iterable
  from pathlib import Path
 
  import stringcase
- from model_lib import Entity, dump, parse_model
+ from model_lib import Entity, dump, parse_model, parse_payload
  from pydantic import ConfigDict, Field, ValidationError
  from zero_3rdparty import file_utils
 
  logger = logging.getLogger(__name__)
 
 
+ def create_sample_file_from_input(samples_dir: Path, log_group_name: str, inputs_file: Path) -> Path:
+     resource_state = parse_payload(inputs_file)
+     assert isinstance(resource_state, dict), f"input file with not a dict {resource_state}"
+     samples_file = samples_dir / inputs_file.name
+     if inputs_file.name.endswith("_create.json"):
+         return create_sample_file(samples_file, log_group_name, resource_state)
+     if inputs_file.name.endswith("_update.json"):
+         prev_state_path = inputs_file.parent / inputs_file.name.replace("_update.json", "_create.json")
+         prev_state: dict = parse_payload(prev_state_path)  # type: ignore
+         return create_sample_file(
+             samples_file,
+             log_group_name,
+             resource_state,
+             prev_resource_state=prev_state,
+         )
+     raise ValueError(f"unexpected input file {inputs_file}")
+
+
  def create_sample_file(
      samples_file: Path,
      log_group_name: str,
      resource_state: dict,
      prev_resource_state: dict | None = None,
- ):
+ ) -> Path:
      logger.info(f"adding sample @ {samples_file}")
      assert isinstance(resource_state, dict)
      new_json = dump(
@@ -29,6 +47,7 @@ def create_sample_file(
          "pretty_json",
      )
      file_utils.ensure_parents_write_text(samples_file, new_json)
+     return samples_file
 
 
  CamelAlias = ConfigDict(alias_generator=stringcase.camelcase, populate_by_name=True)
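`create_sample_file_from_input` now carries the logic that used to live inline in the `inputs` command: it parses an input file, writes a sample with the same name under the samples directory, and for `*_update.json` also attaches the matching `*_create.json` as the previous state. A hedged usage sketch (the directory layout mirrors the conventional `inputs/` and `samples/` folders of a cfn resource package; the resource name is illustrative):

```python
from pathlib import Path

from atlas_init.cli_cfn.files import create_sample_file_from_input

resource_path = Path("cfn-resources/trigger")  # illustrative resource package
samples_dir = resource_path / "samples"
log_group_name = "mongodb-atlas-trigger-logs"  # mirrors the f"mongodb-atlas-{resource_name}-logs" convention

for inputs_file in sorted((resource_path / "inputs").glob("*.json")):
    # Raises ValueError for files that are neither *_create.json nor *_update.json.
    sample = create_sample_file_from_input(samples_dir, log_group_name, inputs_file)
    print("wrote", sample)
```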
@@ -67,8 +67,9 @@ def run_go_tests(
      re_run: bool = False,
      env_vars: GoEnvVars = GoEnvVars.vscode,
      names: set[str] | None = None,
+     use_replay_mode: bool = False,
  ) -> GoTestResult:
-     test_env = _resolve_env_vars(settings, env_vars)
+     test_env = _resolve_env_vars(settings, env_vars, use_replay_mode=use_replay_mode)
      if ci_value := test_env.pop("CI", None):
          logger.warning(f"pooped CI={ci_value}")
      results = GoTestResult()
@@ -115,14 +116,16 @@ def run_go_tests(
      )
 
 
- def _resolve_env_vars(settings: AtlasInitSettings, env_vars: GoEnvVars) -> dict[str, str]:
+ def _resolve_env_vars(settings: AtlasInitSettings, env_vars: GoEnvVars, *, use_replay_mode: bool) -> dict[str, str]:
      if env_vars == GoEnvVars.manual:
          extra_vars = settings.load_profile_manual_env_vars(skip_os_update=True)
      elif env_vars == GoEnvVars.vscode:
          extra_vars = settings.load_env_vars(settings.env_vars_vs_code)
      else:
          raise NotImplementedError(f"don't know how to load env_vars={env_vars}")
-     test_env = os.environ | extra_vars | {"TF_ACC": "1", "TF_LOG": "DEBUG"}
+     mocker_env_name = "HTTP_MOCKER_REPLAY" if use_replay_mode else "HTTP_MOCKER_CAPTURE"
+     extra_vars |= {"TF_ACC": "1", "TF_LOG": "DEBUG", mocker_env_name: "true"}
+     test_env = os.environ | extra_vars
      logger.info(f"go test env-vars-extra: {sorted(extra_vars)}")
      return test_env
 
atlas_init/cli_helper/run.py CHANGED
@@ -5,24 +5,27 @@ from logging import Logger
  from pathlib import Path
  from shutil import which
  from tempfile import TemporaryDirectory
- from typing import IO, TypeVar
+ from typing import IO
 
  import typer
  from zero_3rdparty.id_creator import simple_id
 
- StrT = TypeVar("StrT", bound=str)
+ LOG_CMD_PREFIX = "running: '"
 
 
  def run_command_is_ok(
-     cmd: list[StrT],
+     cmd: str,
      env: dict | None,
      cwd: Path | str,
      logger: Logger,
      output: IO | None = None,
+     *,
+     dry_run: bool = False,
  ) -> bool:
      env = env or {**os.environ}
-     command_str = " ".join(cmd)
-     logger.info(f"running: '{command_str}' from '{cwd}'")
+     logger.info(f"{LOG_CMD_PREFIX}{cmd}' from '{cwd}'")
+     if dry_run:
+         return True
      output = output or sys.stdout  # type: ignore
      exit_code = subprocess.call(
          cmd,
@@ -31,49 +34,44 @@ def run_command_is_ok(
          stdout=output,
          cwd=cwd,
          env=env,
+         shell=True,  # noqa: S602 # We control the calls to this function and don't suspect any shell injection
      )
      is_ok = exit_code == 0
      if is_ok:
-         logger.info(f"success 🥳 '{command_str}'\n")  # adds extra space to separate runs
+         logger.info(f"success 🥳 '{cmd}'\n")  # adds extra space to separate runs
      else:
-         logger.error(f"error 💥, exit code={exit_code}, '{command_str}'")
+         logger.error(f"error 💥, exit code={exit_code}, '{cmd}'")
      return is_ok
 
 
  def run_binary_command_is_ok(
-     binary_name: str, command: str, cwd: Path, logger: Logger, env: dict | None = None
+     binary_name: str, command: str, cwd: Path, logger: Logger, env: dict | None = None, *, dry_run: bool = False
  ) -> bool:
      env = env or {**os.environ}
-
-     bin_path = find_binary_on_path(binary_name, logger)
+     bin_path = find_binary_on_path(binary_name, logger, allow_missing=dry_run) or binary_name
      return run_command_is_ok(
-         [bin_path, *command.split()],
+         f"{bin_path} {command}",
          env=env,
         cwd=cwd,
         logger=logger,
+         dry_run=dry_run,
      )
 
 
  def find_binary_on_path(binary_name: str, logger: Logger, *, allow_missing: bool = False) -> str:
-     bin_path = which(binary_name)
-     if bin_path:
+     if bin_path := which(binary_name):
          return bin_path
      if allow_missing:
+         logger.warning(f"binary '{binary_name}' not found on $PATH")
          return ""
      logger.critical(f"please install '{binary_name}'")
      raise typer.Exit(1)
 
 
  def run_command_exit_on_failure(
-     cmd: list[StrT] | str,
-     cwd: Path | str,
-     logger: Logger,
-     env: dict | None = None,
+     cmd: str, cwd: Path | str, logger: Logger, env: dict | None = None, *, dry_run: bool = False
  ) -> None:
-     if isinstance(cmd, str):
-         cmd = cmd.split()  # type: ignore
-     assert isinstance(cmd, list)
-     if not run_command_is_ok(cmd, cwd=cwd, env=env, logger=logger):
+     if not run_command_is_ok(cmd, cwd=cwd, env=env, logger=logger, dry_run=dry_run):
          logger.critical("command failed, see output 👆")
          raise typer.Exit(1)
 
@@ -84,7 +82,7 @@ def run_command_receive_result(
      with TemporaryDirectory() as temp_dir:
          result_file = Path(temp_dir) / "file"
          with open(result_file, "w") as file:
-             is_ok = run_command_is_ok(command.split(), env=env, cwd=cwd, logger=logger, output=file)
+             is_ok = run_command_is_ok(command, env=env, cwd=cwd, logger=logger, output=file)
          output_text = result_file.read_text().strip()
          if not is_ok:
              if can_fail:
@@ -99,7 +97,7 @@ def run_command_is_ok_output(command: str, cwd: Path, logger: Logger, env: dict
      with TemporaryDirectory() as temp_dir:
          result_file = Path(temp_dir) / f"{simple_id()}.txt"
          with open(result_file, "w") as file:
-             is_ok = run_command_is_ok(command.split(), env=env, cwd=cwd, logger=logger, output=file)
+             is_ok = run_command_is_ok(command, env=env, cwd=cwd, logger=logger, output=file)
          output_text = result_file.read_text().strip()
          return is_ok, output_text
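The run helpers now take a single command string (executed with `shell=True`) instead of an argv list, and the core entry points grew a keyword-only `dry_run` flag that logs the command and reports success without executing it. A hedged usage sketch of the updated signatures:

```python
import logging
from pathlib import Path

from atlas_init.cli_helper.run import run_binary_command_is_ok, run_command_is_ok

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

cwd = Path.cwd()

# Plain string command; with dry_run=True it is only logged, never executed.
assert run_command_is_ok("echo hello", env=None, cwd=cwd, logger=logger, dry_run=True)

# Binary name plus argument string; the binary is resolved on $PATH,
# and dry_run=True tolerates a missing binary.
ok = run_binary_command_is_ok("git", "status --short", cwd=cwd, logger=logger, dry_run=True)
print("would run git status:", ok)
```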