aws-annoying 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. aws_annoying/cli/__init__.py +0 -0
  2. aws_annoying/cli/load_variables.py +139 -0
  3. aws_annoying/{main.py → cli/main.py} +4 -4
  4. aws_annoying/{mfa → cli/mfa}/_app.py +1 -1
  5. aws_annoying/{mfa → cli/mfa}/configure.py +4 -45
  6. aws_annoying/cli/session_manager/__init__.py +3 -0
  7. aws_annoying/{session_manager → cli/session_manager}/_app.py +1 -1
  8. aws_annoying/cli/session_manager/_common.py +54 -0
  9. aws_annoying/{session_manager → cli/session_manager}/install.py +2 -2
  10. aws_annoying/{session_manager → cli/session_manager}/port_forward.py +26 -33
  11. aws_annoying/cli/session_manager/start.py +47 -0
  12. aws_annoying/mfa.py +54 -0
  13. aws_annoying/session_manager/__init__.py +10 -2
  14. aws_annoying/session_manager/session_manager.py +24 -35
  15. aws_annoying/session_manager/shortcuts.py +72 -0
  16. aws_annoying/variables.py +133 -0
  17. {aws_annoying-0.3.0.dist-info → aws_annoying-0.5.0.dist-info}/METADATA +6 -6
  18. aws_annoying-0.5.0.dist-info/RECORD +31 -0
  19. aws_annoying-0.5.0.dist-info/entry_points.txt +2 -0
  20. aws_annoying/load_variables.py +0 -254
  21. aws_annoying/session_manager/_common.py +0 -24
  22. aws_annoying/session_manager/start.py +0 -9
  23. aws_annoying-0.3.0.dist-info/RECORD +0 -26
  24. aws_annoying-0.3.0.dist-info/entry_points.txt +0 -2
  25. /aws_annoying/{app.py → cli/app.py} +0 -0
  26. /aws_annoying/{ecs_task_definition_lifecycle.py → cli/ecs_task_definition_lifecycle.py} +0 -0
  27. /aws_annoying/{mfa → cli/mfa}/__init__.py +0 -0
  28. /aws_annoying/{session_manager → cli/session_manager}/stop.py +0 -0
  29. {aws_annoying-0.3.0.dist-info → aws_annoying-0.5.0.dist-info}/WHEEL +0 -0
  30. {aws_annoying-0.3.0.dist-info → aws_annoying-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -28,15 +28,13 @@ logger = logging.getLogger(__name__)
28
28
  class SessionManager:
29
29
  """AWS Session Manager plugin manager."""
30
30
 
31
- def __init__(self, *, session: boto3.session.Session | None = None, downloader: AbstractDownloader) -> None:
31
+ def __init__(self, *, session: boto3.session.Session | None = None) -> None:
32
32
  """Initialize SessionManager.
33
33
 
34
34
  Args:
35
35
  session: Boto3 session to use for AWS operations.
36
- downloader: File downloader to use for downloading the plugin.
37
36
  """
38
37
  self.session = session or boto3.session.Session()
39
- self.downloader = downloader
40
38
 
41
39
  # ------------------------------------------------------------------------
42
40
  # Installation
@@ -90,6 +88,7 @@ class SessionManager:
90
88
  linux_distribution: _LinuxDistribution | None = None,
91
89
  arch: str | None = None,
92
90
  root: bool | None = None,
91
+ downloader: AbstractDownloader,
93
92
  ) -> None:
94
93
  """Install AWS Session Manager plugin.
95
94
 
@@ -99,17 +98,23 @@ class SessionManager:
99
98
  If `None` and current `os` is `"Linux"`, will try to detect the distribution from current system.
100
99
  arch: The architecture to install the plugin on. If `None`, will use the current architecture.
101
100
  root: Whether to run the installation as root. If `None`, will check if the current user is root.
101
+ downloader: File downloader to use for downloading the plugin.
102
102
  """
103
103
  os = os or platform.system()
104
104
  arch = arch or platform.machine()
105
105
 
106
106
  if os == "Windows":
107
- self._install_windows()
107
+ self._install_windows(downloader=downloader)
108
108
  elif os == "Darwin":
109
- self._install_macos(arch=arch, root=root or is_root())
109
+ self._install_macos(arch=arch, root=root or is_root(), downloader=downloader)
110
110
  elif os == "Linux":
111
111
  linux_distribution = linux_distribution or _detect_linux_distribution()
112
- self._install_linux(linux_distribution=linux_distribution, arch=arch, root=root or is_root())
112
+ self._install_linux(
113
+ linux_distribution=linux_distribution,
114
+ arch=arch,
115
+ root=root or is_root(),
116
+ downloader=downloader,
117
+ )
113
118
  else:
114
119
  msg = f"Unsupported operating system: {os}"
115
120
  raise UnsupportedPlatformError(msg)
@@ -118,20 +123,20 @@ class SessionManager:
118
123
  """Hook to run before invoking plugin installation command."""
119
124
 
120
125
  # https://docs.aws.amazon.com/systems-manager/latest/userguide/install-plugin-windows.html
121
- def _install_windows(self) -> None:
126
+ def _install_windows(self, *, downloader: AbstractDownloader) -> None:
122
127
  """Install session-manager-plugin on Windows via EXE installer."""
123
128
  download_url = (
124
129
  "https://s3.amazonaws.com/session-manager-downloads/plugin/latest/windows/SessionManagerPluginSetup.exe"
125
130
  )
126
131
  with tempfile.TemporaryDirectory() as temp_dir:
127
132
  p = Path(temp_dir)
128
- exe_installer = self.downloader.download(download_url, to=p / "SessionManagerPluginSetup.exe")
133
+ exe_installer = downloader.download(download_url, to=p / "SessionManagerPluginSetup.exe")
129
134
  command = [str(exe_installer), "/quiet"]
130
135
  self.before_install(command)
131
136
  subprocess.call(command, cwd=p) # noqa: S603
132
137
 
133
138
  # https://docs.aws.amazon.com/systems-manager/latest/userguide/install-plugin-macos-overview.html
134
- def _install_macos(self, *, arch: str, root: bool) -> None:
139
+ def _install_macos(self, *, arch: str, root: bool, downloader: AbstractDownloader) -> None:
135
140
  """Install session-manager-plugin on macOS via signed installer."""
136
141
  # ! Intel chip will not be supported
137
142
  if arch == "x86_64":
@@ -148,7 +153,7 @@ class SessionManager:
148
153
 
149
154
  with tempfile.TemporaryDirectory() as temp_dir:
150
155
  p = Path(temp_dir)
151
- pkg_installer = self.downloader.download(download_url, to=p / "session-manager-plugin.pkg")
156
+ pkg_installer = downloader.download(download_url, to=p / "session-manager-plugin.pkg")
152
157
 
153
158
  # Run installer
154
159
  command = command_as_root(
@@ -179,6 +184,7 @@ class SessionManager:
179
184
  linux_distribution: _LinuxDistribution,
180
185
  arch: str,
181
186
  root: bool,
187
+ downloader: AbstractDownloader,
182
188
  ) -> None:
183
189
  name = linux_distribution.name
184
190
  version = linux_distribution.version
@@ -200,7 +206,7 @@ class SessionManager:
200
206
 
201
207
  with tempfile.TemporaryDirectory() as temp_dir:
202
208
  p = Path(temp_dir)
203
- deb_installer = self.downloader.download(download_url, to=p / "session-manager-plugin.deb")
209
+ deb_installer = downloader.download(download_url, to=p / "session-manager-plugin.deb")
204
210
 
205
211
  # Invoke installation command
206
212
  command = command_as_root(["dpkg", "--install", str(deb_installer)], root=root)
@@ -241,28 +247,25 @@ class SessionManager:
241
247
  raise UnsupportedPlatformError(msg)
242
248
 
243
249
  # ------------------------------------------------------------------------
244
- # Session
250
+ # Command
245
251
  # ------------------------------------------------------------------------
246
- def start(
252
+ def build_command(
247
253
  self,
248
- *,
249
254
  target: str,
250
255
  document_name: str,
251
256
  parameters: dict[str, Any],
252
257
  reason: str | None = None,
253
- log_file: Path | None = None,
254
- ) -> subprocess.Popen:
255
- """Start new session.
258
+ ) -> list[str]:
259
+ """Build command for starting a session.
256
260
 
257
261
  Args:
258
- target: The target instance ID or name.
262
+ target: The target instance ID.
259
263
  document_name: The SSM document name to use for the session.
260
264
  parameters: The parameters to pass to the SSM document.
261
265
  reason: The reason for starting the session.
262
- log_file: Optional file to log output to.
263
266
 
264
267
  Returns:
265
- Process ID of the session.
268
+ The command to start the session.
266
269
  """
267
270
  is_installed, binary_path, version = self.verify_installation()
268
271
  if not is_installed:
@@ -279,7 +282,7 @@ class SessionManager:
279
282
  )
280
283
 
281
284
  region = self.session.region_name
282
- command = [
285
+ return [
283
286
  str(binary_path),
284
287
  json.dumps(response),
285
288
  region,
@@ -289,20 +292,6 @@ class SessionManager:
289
292
  f"https://ssm.{region}.amazonaws.com",
290
293
  ]
291
294
 
292
- stdout: subprocess._FILE
293
- if log_file is not None: # noqa: SIM108
294
- stdout = log_file.open(mode="at+", buffering=1)
295
- else:
296
- stdout = subprocess.DEVNULL
297
-
298
- return subprocess.Popen( # noqa: S603
299
- command,
300
- stdout=stdout,
301
- stderr=subprocess.STDOUT,
302
- text=True,
303
- close_fds=False, # FD inherited from parent process
304
- )
305
-
306
295
 
307
296
  # ? Could be moved to utils, but didn't because it's too specific to this module
308
297
  class _LinuxDistribution(NamedTuple):
@@ -0,0 +1,72 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import subprocess
5
+ from contextlib import contextmanager
6
+ from typing import TYPE_CHECKING
7
+
8
+ from .session_manager import SessionManager
9
+
10
+ if TYPE_CHECKING:
11
+ from collections.abc import Iterator
12
+
13
+
14
logger = logging.getLogger(__name__)

# Seconds to wait for the plugin to exit after SIGTERM before it is killed.
_TERMINATE_TIMEOUT = 5.0

# Output line treated as the readiness signal of the port-forwarding session.
_READY_MARKER = "Waiting for connections..."


@contextmanager
def port_forward(
    *,
    through: str,
    local_port: int,
    remote_host: str,
    remote_port: int,
    reason: str | None = None,
) -> Iterator[subprocess.Popen[str]]:
    """Context manager for port forwarding sessions.

    Builds a session command for the ``AWS-StartPortForwardingSessionToRemoteHost``
    SSM document, starts it as a subprocess, waits until it prints the readiness
    marker and yields the running process. On exit the process is terminated and
    reaped (killed if it does not exit within ``_TERMINATE_TIMEOUT`` seconds).

    Args:
        through: The instance ID to use as port-forwarding proxy.
        local_port: The local port to listen to.
        remote_host: The remote host to connect to.
        remote_port: The remote port to connect to.
        reason: The reason for starting the session.

    Yields:
        The running session process; stdout and stderr are merged into ``proc.stdout``.

    Raises:
        RuntimeError: If the process has no stdout pipe, or its output ends before
            the readiness marker is printed.
    """
    session_manager = SessionManager()
    command = session_manager.build_command(
        target=through,
        document_name="AWS-StartPortForwardingSessionToRemoteHost",
        parameters={
            "localPortNumber": [str(local_port)],
            "host": [remote_host],
            "portNumber": [str(remote_port)],
        },
        reason=reason,
    )

    # Spawn outside ``try``: if ``Popen`` itself fails there is nothing to clean up,
    # and the ``finally`` below would otherwise reference an unbound ``proc`` and
    # mask the real error with ``UnboundLocalError``.
    proc = subprocess.Popen(  # noqa: S603
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    try:
        # * Must be unreachable
        if proc.stdout is None:
            msg = "Standard output is not available"
            raise RuntimeError(msg)

        # Wait for the session to start
        # ? Not sure this is trustworthy health check
        # TODO(lasuillard): Need timeout to avoid hanging forever
        early_output: list[str] = []
        for line in proc.stdout:
            if _READY_MARKER in line:
                logger.info("Session started successfully.")
                break
            early_output.append(line)
        else:
            # Output ended without the marker: do not hand a dead session to the caller.
            msg = f"Session ended before it was ready: {''.join(early_output)!r}"
            raise RuntimeError(msg)

        # NOTE(review): nothing drains ``proc.stdout`` after this point; if the plugin
        # keeps writing while the session is open, the pipe buffer may fill up.
        yield proc
    finally:
        proc.terminate()
        try:
            proc.wait(timeout=_TERMINATE_TIMEOUT)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
        if proc.stdout is not None:
            proc.stdout.close()
@@ -0,0 +1,133 @@
1
+ # flake8: noqa: B008
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ from typing import Any, TypedDict
6
+
7
+ import boto3
8
+
9
# Readability aliases used in the signatures below.
_ARN = str  # an AWS resource ARN
_Variables = dict[str, Any]  # variable name -> value, decoded from one JSON source

# TODO(lasuillard): Need some refactoring (with #2, #3)
# TODO(lasuillard): Put some logging

# Counters returned alongside the merged variables: how many secret and
# parameter sources were loaded.
_LoadStatsDict = TypedDict("_LoadStatsDict", {"secrets": int, "parameters": int})
20
+
21
+
22
class VariableLoader:  # noqa: D101
    # Upper bounds on identifiers per request imposed by the AWS APIs
    # (BatchGetSecretValue ``SecretIdList`` and GetParameters ``Names``).
    _SECRETS_BATCH_SIZE = 20
    _PARAMETERS_BATCH_SIZE = 10

    def __init__(self, *, dry_run: bool) -> None:
        """Initialize the VariableLoader.

        Args:
            dry_run: Whether to run in dry-run mode. When enabled, no AWS API calls are
                made and every source resolves to an empty set of variables.
        """
        self.dry_run = dry_run

    def load(self, map_arns: dict[str, _ARN]) -> tuple[dict[str, Any], _LoadStatsDict]:
        """Load the variables from the AWS Secrets Manager and SSM Parameter Store.

        Each secret or parameter should be a valid JSON dictionary, where the keys are
        the variable names and the values are the variable values.

        The items are merged in the order of the keys of the provided mapping (plain
        string sort, so ``"10"`` sorts before ``"2"``), later items overwriting the
        variables with the same name.

        Args:
            map_arns: Mapping of order key to a Secrets Manager or SSM ARN. ARNs from
                any AWS partition (e.g. ``aws``, ``aws-cn``) are accepted.

        Returns:
            The merged variables and the number of secret / parameter sources loaded.

        Raises:
            ValueError: If an ARN is unsupported, AWS reports errors, or a value is not JSON.
            TypeError: If a source does not decode to a dictionary or has no string value.
        """
        # Split the ARNs by resource types
        secrets_map: dict[str, _ARN] = {}
        parameters_map: dict[str, _ARN] = {}
        for idx, arn in map_arns.items():
            service = self._service_of(arn)
            if service == "secretsmanager":
                secrets_map[idx] = arn
            elif service == "ssm":
                parameters_map[idx] = arn
            else:
                msg = f"Unsupported resource: {arn!r}"
                raise ValueError(msg)

        # Retrieve variables from AWS resources
        secrets: dict[str, _Variables]
        parameters: dict[str, _Variables]
        if self.dry_run:
            secrets = {idx: {} for idx in secrets_map}
            parameters = {idx: {} for idx in parameters_map}
        else:
            secrets = self._retrieve_secrets(secrets_map)
            parameters = self._retrieve_parameters(parameters_map)

        load_stats: _LoadStatsDict = {
            "secrets": len(secrets),
            "parameters": len(parameters),
        }

        # Merge in key order. The two dicts never share a key because every index
        # was routed to exactly one of them above.
        full_variables = secrets | parameters
        merged_in_order: dict[str, Any] = {}
        for idx in sorted(full_variables):
            merged_in_order.update(full_variables[idx])

        return merged_in_order, load_stats

    def _retrieve_secrets(self, secrets_map: dict[str, _ARN]) -> dict[str, _Variables]:
        """Retrieve the secrets from AWS Secrets Manager.

        Each distinct ARN is fetched once (in batches within the API limit) and its
        decoded dictionary is assigned to every order key that references it.
        """
        if not secrets_map:
            return {}

        secretsmanager = boto3.client("secretsmanager")
        keys_by_arn = self._group_keys_by_arn(secrets_map)

        result: dict[str, _Variables] = {}
        for batch in self._batched(list(keys_by_arn), self._SECRETS_BATCH_SIZE):
            response = secretsmanager.batch_get_secret_value(SecretIdList=batch)
            if errors := response.get("Errors"):
                msg = f"Failed to retrieve secrets: {errors!r}"
                raise ValueError(msg)

            for secret in response["SecretValues"]:
                arn = secret["ARN"]
                if "SecretString" not in secret:
                    # e.g. a binary secret: it cannot carry a JSON dictionary of variables.
                    msg = f"Secret {arn!r} has no string value"
                    raise TypeError(msg)
                data = self._parse_variables(secret["SecretString"], label="Secret")
                for key in self._keys_for(keys_by_arn, arn):
                    result[key] = data

        return result

    def _retrieve_parameters(self, parameters_map: dict[str, _ARN]) -> dict[str, _Variables]:
        """Retrieve the parameters from AWS SSM Parameter Store.

        Each distinct ARN is fetched once (in batches within the API limit, with
        decryption) and its decoded dictionary is assigned to every order key that
        references it.
        """
        if not parameters_map:
            return {}

        ssm = boto3.client("ssm")
        keys_by_arn = self._group_keys_by_arn(parameters_map)

        result: dict[str, _Variables] = {}
        for batch in self._batched(list(keys_by_arn), self._PARAMETERS_BATCH_SIZE):
            response = ssm.get_parameters(Names=batch, WithDecryption=True)
            if errors := response.get("InvalidParameters"):
                msg = f"Failed to retrieve parameters: {errors!r}"
                raise ValueError(msg)

            for parameter in response["Parameters"]:
                arn = parameter["ARN"]
                data = self._parse_variables(parameter["Value"], label="Parameter")
                for key in self._keys_for(keys_by_arn, arn):
                    result[key] = data

        return result

    @staticmethod
    def _service_of(arn: str) -> str | None:
        """Return the service segment of ``arn:<partition>:<service>:...``, or ``None`` if malformed."""
        parts = arn.split(":", 3)
        if len(parts) < 4 or parts[0] != "arn" or not parts[1]:
            return None
        return parts[2]

    @staticmethod
    def _group_keys_by_arn(arn_map: dict[str, _ARN]) -> dict[_ARN, list[str]]:
        """Invert ``arn_map`` so each distinct ARN lists every order key referencing it."""
        grouped: dict[_ARN, list[str]] = {}
        for key, arn in arn_map.items():
            grouped.setdefault(arn, []).append(key)
        return grouped

    @staticmethod
    def _batched(items: list[str], size: int) -> list[list[str]]:
        """Split ``items`` into consecutive chunks of at most ``size`` elements."""
        return [items[start : start + size] for start in range(0, len(items), size)]

    @staticmethod
    def _keys_for(keys_by_arn: dict[_ARN, list[str]], arn: _ARN) -> list[str]:
        """Return the order keys for a response ARN, failing clearly on an unrequested one."""
        keys = keys_by_arn.get(arn)
        if keys is None:
            msg = f"Unexpected resource in response: {arn!r}"
            raise ValueError(msg)
        return keys

    @staticmethod
    def _parse_variables(raw: str, *, label: str) -> _Variables:
        """Decode ``raw`` as JSON and ensure it is a dictionary of variables."""
        data = json.loads(raw)
        if not isinstance(data, dict):
            msg = f"{label} data must be a valid dictionary, but got: {type(data)!r}"
            raise TypeError(msg)
        return data
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aws-annoying
3
- Version: 0.3.0
3
+ Version: 0.5.0
4
4
  Summary: Utils to handle some annoying AWS tasks.
5
5
  Project-URL: Homepage, https://github.com/lasuillard/aws-annoying
6
6
  Project-URL: Repository, https://github.com/lasuillard/aws-annoying.git
@@ -9,11 +9,11 @@ Author-email: Yuchan Lee <lasuillard@gmail.com>
9
9
  License-Expression: MIT
10
10
  License-File: LICENSE
11
11
  Requires-Python: <4.0,>=3.9
12
- Requires-Dist: boto3>=1.37.1
13
- Requires-Dist: pydantic>=2.10.6
14
- Requires-Dist: requests>=2.32.3
15
- Requires-Dist: tqdm>=4.67.1
16
- Requires-Dist: typer>=0.15.1
12
+ Requires-Dist: boto3<2,>=1
13
+ Requires-Dist: pydantic<3,>=2
14
+ Requires-Dist: requests<3,>=2
15
+ Requires-Dist: tqdm<5,>=4
16
+ Requires-Dist: typer<1,>=0
17
17
  Provides-Extra: dev
18
18
  Requires-Dist: boto3-stubs[ec2,ecs,secretsmanager,ssm,sts]>=1.37.1; extra == 'dev'
19
19
  Requires-Dist: mypy~=1.15.0; extra == 'dev'
@@ -0,0 +1,31 @@
1
+ aws_annoying/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ aws_annoying/mfa.py,sha256=m6-V1bWeUWsAmRddl-lv13mPCMnftoPzJoNnZ0kiaWQ,2007
3
+ aws_annoying/variables.py,sha256=a9cMS9JU-XA2h1tztO7ofixoDEpqtS_eVEiWrQ75mTo,4761
4
+ aws_annoying/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ aws_annoying/cli/app.py,sha256=sp50uVoAl4D6Wk3DFpzKZzSsxmSxNYejFxm62b_Kxps,201
6
+ aws_annoying/cli/ecs_task_definition_lifecycle.py,sha256=O36Bf5LBnVJNyYmdlUxhtsIHNoxky1t5YacAXiL9UEI,2803
7
+ aws_annoying/cli/load_variables.py,sha256=eWNByUEc1ijF8uCe_egdAnjWxfMNCZeVr0vtTtQLe3Y,5086
8
+ aws_annoying/cli/main.py,sha256=TSzPeMkgIgKFf3bom_vDkFYK0bHF1r5K9ADreZUV3k4,503
9
+ aws_annoying/cli/mfa/__init__.py,sha256=rbEGhw5lOQZV_XAc3nSbo56JVhsSPpeOgEtiAy9qzEA,50
10
+ aws_annoying/cli/mfa/_app.py,sha256=Ub7gxb6kGF3Ve1ucQSOjHmc4jAu8mxgegcXsIbOzLLQ,189
11
+ aws_annoying/cli/mfa/configure.py,sha256=vsoHfTVFF2dPgiYsp2L-EkMwtAA0_-tVwFd6Wv6DscU,3746
12
+ aws_annoying/cli/session_manager/__init__.py,sha256=FkT6jT6OXduOURN61d-U6hgd-XluQbvuVtKXXiXgSEk,105
13
+ aws_annoying/cli/session_manager/_app.py,sha256=OVOHW0iyKzunvaqLhjoseHw1-WxJ1gGb7QmiyAEezyY,221
14
+ aws_annoying/cli/session_manager/_common.py,sha256=u23F4mJOHWHphLYL1gOAh8J3a_Odyk4VVvI175KWmzg,1616
15
+ aws_annoying/cli/session_manager/install.py,sha256=zcQi91xVFKhbSOD4VBc6YG9-fDnhymVyK63PlITdxug,1445
16
+ aws_annoying/cli/session_manager/port_forward.py,sha256=J8_CIrTsbcOYRYOdHkdz81dkaNWxI_l70E8gyJ-Ukh8,4192
17
+ aws_annoying/cli/session_manager/start.py,sha256=pPS0jKuURGTX-WTix3owqqisX-bmCydqfyqC0kFGnt8,1358
18
+ aws_annoying/cli/session_manager/stop.py,sha256=ttU6nlbVgBkZDtY-DwUyCstv5TFtat5TljkyuY8QICU,1482
19
+ aws_annoying/session_manager/__init__.py,sha256=IENviL3ux2LF7o9xFGYEiqaGw03hxnyNX2btbB1xyEU,318
20
+ aws_annoying/session_manager/errors.py,sha256=YioKlRtZ-GUP0F_ts_ebw7-HYkxe8mTes6HK821Kuiw,353
21
+ aws_annoying/session_manager/session_manager.py,sha256=myZxY_WE4akdlTsH1mOvf0Ublwg-hf1vEkEcmdZyYSU,12147
22
+ aws_annoying/session_manager/shortcuts.py,sha256=uFRPGia_5gqfBDxwOjmLg7UFzhvkSFUqopWuzN5_kbA,1973
23
+ aws_annoying/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
+ aws_annoying/utils/debugger.py,sha256=UFllDCGI2gPtwo1XS5vqw0qyR6bYr7XknmBwSxalKIc,754
25
+ aws_annoying/utils/downloader.py,sha256=aB5RzT-LpbFX24-2HXlAkdgVowc4TR9FWT_K8WwZ1BE,1923
26
+ aws_annoying/utils/platform.py,sha256=h3DUWmTMM-_4TfTWNqY0uNqyVsBjAuMm2DEbG-daxe8,742
27
+ aws_annoying-0.5.0.dist-info/METADATA,sha256=kHaGAHqfkZ8Ip4NzcLbE7Bh01oA1CvoTsUa3Ljx1ctU,1916
28
+ aws_annoying-0.5.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
29
+ aws_annoying-0.5.0.dist-info/entry_points.txt,sha256=DcKE5V0WvVJ8wUOHxyUz1yLAJOuuJUgRPlMcQ4O7jEs,66
30
+ aws_annoying-0.5.0.dist-info/licenses/LICENSE,sha256=Q5GkvYijQ2KTQ-QWhv43ilzCno4ZrzrEuATEQZd9rYo,1067
31
+ aws_annoying-0.5.0.dist-info/RECORD,,
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ aws-annoying = aws_annoying.cli.main:entrypoint
@@ -1,254 +0,0 @@
1
- # flake8: noqa: B008
2
- from __future__ import annotations
3
-
4
- import json
5
- import os
6
- import subprocess
7
- from typing import Any, NoReturn, Optional
8
-
9
- import boto3
10
- import typer
11
- from rich.console import Console
12
- from rich.table import Table
13
-
14
- from .app import app
15
-
16
-
17
@app.command(
    context_settings={
        # Allow extra arguments for user provided command
        "allow_extra_args": True,
        "ignore_unknown_options": True,
    },
)
def load_variables(  # noqa: PLR0913
    *,
    ctx: typer.Context,
    arns: list[str] = typer.Option(
        [],
        metavar="ARN",
        help=(
            "ARNs of the secret or parameter to load."
            " The variables are loaded in the order of the ARNs,"
            " overwriting the variables with the same name in the order of the ARNs."
        ),
    ),
    env_prefix: Optional[str] = typer.Option(
        None,
        help="Prefix of the environment variables to load the ARNs from.",
        show_default=False,
    ),
    overwrite_env: bool = typer.Option(
        False,  # noqa: FBT003
        help="Overwrite the existing environment variables with the same name.",
    ),
    quiet: bool = typer.Option(
        False,  # noqa: FBT003
        help="Suppress all outputs from this command.",
    ),
    dry_run: bool = typer.Option(
        False,  # noqa: FBT003
        help="Print the progress only. Neither load variables nor run the command.",
    ),
    replace: bool = typer.Option(
        True,  # noqa: FBT003
        help=(
            "Replace the current process (`os.execvpe`) with the command."
            " If disabled, run the command as a `subprocess`."
        ),
    ),
) -> NoReturn:
    """Wrapper command to run command with variables from AWS resources injected as environment variables.

    This script is intended to be used in the ECS environment, where currently AWS does not support
    injecting whole JSON dictionary of secrets or parameters as environment variables directly.

    It first loads the variables from the AWS sources then runs the command with the variables injected as environment variables.

    In addition to `--arns` option, you can provide ARNs as the environment variables by providing `--env-prefix`.
    For example, if you have the following environment variables:

    ```shell
    export LOAD_AWS_CONFIG__001_app_config=arn:aws:secretsmanager:...
    export LOAD_AWS_CONFIG__002_db_config=arn:aws:ssm:...
    ```

    You can run the following command:

    ```shell
    aws-annoying load-variables --env-prefix LOAD_AWS_CONFIG__ -- ...
    ```

    The variables are loaded in the order of option provided, overwriting the variables with the same name in the order of the ARNs.
    Existing environment variables are preserved by default, unless `--overwrite-env` is provided.
    """  # noqa: E501
    console = Console(quiet=quiet, emoji=False)

    command = ctx.args
    if not command:
        console.print("⚠️ No command provided. Exiting...")
        raise typer.Exit(0)

    # Mapping of the ARNs by order key. The loader merges in plain string order,
    # so indices are zero-padded to a common width to keep "10" after "2" once
    # there are 10+ ARNs. With fewer than 10 ARNs the width is 1 (keys unchanged).
    width = len(str(max(len(arns) - 1, 0)))
    map_arns_by_index = {f"{idx:0{width}d}": arn for idx, arn in enumerate(arns)}
    if env_prefix:
        console.print(f"🔍 Loading ARNs from environment variables with prefix: {env_prefix!r}")
        arns_env = {
            key.removeprefix(env_prefix): value for key, value in os.environ.items() if key.startswith(env_prefix)
        }
        console.print(f"🔍 Found {len(arns_env)} sources from environment variables.")
        map_arns_by_index = arns_env | map_arns_by_index

    # Briefly show the ARNs
    table = Table("Index", "ARN")
    for idx, arn in sorted(map_arns_by_index.items()):
        table.add_row(idx, arn)

    console.print(table)

    # Retrieve the variables
    loader = VariableLoader(dry_run=dry_run, console=console)
    try:
        variables = loader.load(map_arns_by_index)
    except Exception as exc:  # noqa: BLE001
        console.print(f"❌ Failed to load the variables: {exc!s}")
        raise typer.Exit(1) from None

    # `--dry-run` promises not to run the command: stop after reporting progress.
    if dry_run:
        console.print(f"⚠️ Dry run mode enabled. Skipping the command: [bold orchid]{' '.join(command)}[/bold orchid]")
        raise typer.Exit(0)

    # Prepare the environment variables. Environment values must be strings, so
    # stringify on both paths (the overwrite path used to pass raw JSON values).
    env = os.environ.copy()
    string_variables = {key: str(value) for key, value in variables.items()}
    if overwrite_env:
        env.update(string_variables)
    else:
        # Update variables, preserving the existing ones
        for key, value in string_variables.items():
            env.setdefault(key, value)

    # Run the command with the variables injected as environment variables, replacing current process
    console.print(f"🚀 Running the command: [bold orchid]{' '.join(command)}[/bold orchid]")
    if replace:  # pragma: no cover (not coverable)
        os.execvpe(command[0], command, env=env)  # noqa: S606
        # The above line should never return

    result = subprocess.run(command, env=env, check=False)  # noqa: S603
    raise typer.Exit(result.returncode)
134
-
135
-
136
# Type aliases for readability
_ARN = str
_Variables = dict[str, Any]


class VariableLoader:  # noqa: D101
    # Upper bounds on identifiers per request imposed by the AWS APIs
    # (BatchGetSecretValue ``SecretIdList`` and GetParameters ``Names``).
    _SECRETS_BATCH_SIZE = 20
    _PARAMETERS_BATCH_SIZE = 10

    def __init__(self, *, console: Console | None = None, dry_run: bool) -> None:
        """Initialize the VariableLoader.

        Args:
            dry_run: Whether to run in dry-run mode (no AWS calls; every source is empty).
            console: Rich console instance. Defaults to a quiet console.
        """
        self.console = console or Console(quiet=True)
        self.dry_run = dry_run

    # TODO(lasuillard): Currently not using pagination (do we need more than 10-20 secrets or parameters each?)
    # ; consider adding it if needed
    def load(self, map_arns: dict[str, _ARN]) -> dict[str, Any]:
        """Load the variables from the AWS Secrets Manager and SSM Parameter Store.

        Each secret or parameter should be a valid JSON dictionary, where the keys are
        the variable names and the values are the variable values.

        The items are merged in the order of the keys of the provided mapping (plain
        string sort), later items overwriting the variables with the same name.

        Args:
            map_arns: Mapping of order key to a Secrets Manager or SSM ARN (any partition).

        Returns:
            The merged variables.

        Raises:
            ValueError: If an ARN is unsupported, AWS reports errors, or a value is not JSON.
            TypeError: If a source does not decode to a dictionary or has no string value.
        """
        self.console.print("🔍 Retrieving variables from AWS resources...")
        if self.dry_run:
            self.console.print("⚠️ Dry run mode enabled. Variables won't be loaded from AWS.")

        # Split the ARNs by resource types
        secrets_map: dict[str, _ARN] = {}
        parameters_map: dict[str, _ARN] = {}
        for idx, arn in map_arns.items():
            service = self._service_of(arn)
            if service == "secretsmanager":
                secrets_map[idx] = arn
            elif service == "ssm":
                parameters_map[idx] = arn
            else:
                msg = f"Unsupported resource: {arn!r}"
                raise ValueError(msg)

        # Retrieve variables from AWS resources
        secrets: dict[str, _Variables]
        parameters: dict[str, _Variables]
        if self.dry_run:
            secrets = {idx: {} for idx in secrets_map}
            parameters = {idx: {} for idx in parameters_map}
        else:
            secrets = self._retrieve_secrets(secrets_map)
            parameters = self._retrieve_parameters(parameters_map)

        self.console.print(f"✅ Retrieved {len(secrets)} secrets and {len(parameters)} parameters.")

        # Merge in key order; the two dicts never share a key (each index was routed once).
        full_variables = secrets | parameters
        merged_in_order: dict[str, Any] = {}
        for idx in sorted(full_variables):
            merged_in_order.update(full_variables[idx])

        return merged_in_order

    def _retrieve_secrets(self, secrets_map: dict[str, _ARN]) -> dict[str, _Variables]:
        """Retrieve the secrets from AWS Secrets Manager, one fetch per distinct ARN."""
        if not secrets_map:
            return {}

        secretsmanager = boto3.client("secretsmanager")
        keys_by_arn = self._group_keys_by_arn(secrets_map)

        result: dict[str, _Variables] = {}
        for batch in self._batched(list(keys_by_arn), self._SECRETS_BATCH_SIZE):
            response = secretsmanager.batch_get_secret_value(SecretIdList=batch)
            if errors := response.get("Errors"):
                msg = f"Failed to retrieve secrets: {errors!r}"
                raise ValueError(msg)

            for secret in response["SecretValues"]:
                arn = secret["ARN"]
                if "SecretString" not in secret:
                    # e.g. a binary secret: it cannot carry a JSON dictionary of variables.
                    msg = f"Secret {arn!r} has no string value"
                    raise TypeError(msg)
                data = self._parse_variables(secret["SecretString"], label="Secret")
                for key in self._keys_for(keys_by_arn, arn):
                    result[key] = data

        return result

    def _retrieve_parameters(self, parameters_map: dict[str, _ARN]) -> dict[str, _Variables]:
        """Retrieve the parameters from AWS SSM Parameter Store, one fetch per distinct ARN."""
        if not parameters_map:
            return {}

        ssm = boto3.client("ssm")
        keys_by_arn = self._group_keys_by_arn(parameters_map)

        result: dict[str, _Variables] = {}
        for batch in self._batched(list(keys_by_arn), self._PARAMETERS_BATCH_SIZE):
            response = ssm.get_parameters(Names=batch, WithDecryption=True)
            if errors := response.get("InvalidParameters"):
                msg = f"Failed to retrieve parameters: {errors!r}"
                raise ValueError(msg)

            for parameter in response["Parameters"]:
                arn = parameter["ARN"]
                data = self._parse_variables(parameter["Value"], label="Parameter")
                for key in self._keys_for(keys_by_arn, arn):
                    result[key] = data

        return result

    @staticmethod
    def _service_of(arn: str) -> str | None:
        """Return the service segment of ``arn:<partition>:<service>:...``, or ``None`` if malformed."""
        parts = arn.split(":", 3)
        if len(parts) < 4 or parts[0] != "arn" or not parts[1]:
            return None
        return parts[2]

    @staticmethod
    def _group_keys_by_arn(arn_map: dict[str, _ARN]) -> dict[_ARN, list[str]]:
        """Invert ``arn_map`` so each distinct ARN lists every order key referencing it."""
        grouped: dict[_ARN, list[str]] = {}
        for key, arn in arn_map.items():
            grouped.setdefault(arn, []).append(key)
        return grouped

    @staticmethod
    def _batched(items: list[str], size: int) -> list[list[str]]:
        """Split ``items`` into consecutive chunks of at most ``size`` elements."""
        return [items[start : start + size] for start in range(0, len(items), size)]

    @staticmethod
    def _keys_for(keys_by_arn: dict[_ARN, list[str]], arn: _ARN) -> list[str]:
        """Return the order keys for a response ARN, failing clearly on an unrequested one."""
        keys = keys_by_arn.get(arn)
        if keys is None:
            msg = f"Unexpected resource in response: {arn!r}"
            raise ValueError(msg)
        return keys

    @staticmethod
    def _parse_variables(raw: str, *, label: str) -> _Variables:
        """Decode ``raw`` as JSON and ensure it is a dictionary of variables."""
        data = json.loads(raw)
        if not isinstance(data, dict):
            msg = f"{label} data must be a valid dictionary, but got: {type(data)!r}"
            raise TypeError(msg)
        return data