aws-annoying 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. aws_annoying/cli/app.py +81 -0
  2. aws_annoying/cli/ecs/__init__.py +3 -0
  3. aws_annoying/cli/ecs/_app.py +9 -0
  4. aws_annoying/cli/{ecs_task_definition_lifecycle.py → ecs/task_definition_lifecycle.py} +18 -13
  5. aws_annoying/cli/ecs/wait_for_deployment.py +158 -0
  6. aws_annoying/cli/load_variables.py +20 -25
  7. aws_annoying/cli/logging_handler.py +52 -0
  8. aws_annoying/cli/main.py +1 -1
  9. aws_annoying/cli/mfa/configure.py +21 -12
  10. aws_annoying/cli/session_manager/_common.py +1 -32
  11. aws_annoying/cli/session_manager/install.py +8 -5
  12. aws_annoying/cli/session_manager/port_forward.py +22 -12
  13. aws_annoying/cli/session_manager/start.py +13 -5
  14. aws_annoying/cli/session_manager/stop.py +9 -7
  15. aws_annoying/ecs/__init__.py +25 -0
  16. aws_annoying/ecs/check.py +39 -0
  17. aws_annoying/ecs/common.py +8 -0
  18. aws_annoying/ecs/errors.py +14 -0
  19. aws_annoying/ecs/wait_for.py +190 -0
  20. aws_annoying/{mfa.py → mfa_config.py} +7 -2
  21. aws_annoying/session_manager/session_manager.py +2 -4
  22. aws_annoying/session_manager/shortcuts.py +10 -6
  23. aws_annoying/utils/downloader.py +1 -8
  24. aws_annoying/utils/ec2.py +33 -0
  25. aws_annoying/utils/platform.py +11 -0
  26. aws_annoying/utils/timeout.py +85 -0
  27. aws_annoying/{variables.py → variable_loader.py} +11 -16
  28. {aws_annoying-0.5.0.dist-info → aws_annoying-0.7.0.dist-info}/METADATA +48 -3
  29. aws_annoying-0.7.0.dist-info/RECORD +42 -0
  30. aws_annoying-0.5.0.dist-info/RECORD +0 -31
  31. {aws_annoying-0.5.0.dist-info → aws_annoying-0.7.0.dist-info}/WHEEL +0 -0
  32. {aws_annoying-0.5.0.dist-info → aws_annoying-0.7.0.dist-info}/entry_points.txt +0 -0
  33. {aws_annoying-0.5.0.dist-info → aws_annoying-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,13 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import logging
4
+
3
5
  import typer
4
- from rich import print # noqa: A004
5
6
 
6
7
  from aws_annoying.utils.downloader import TQDMDownloader
7
8
 
8
9
  from ._app import session_manager_app
9
10
  from ._common import SessionManager
10
11
 
12
+ logger = logging.getLogger(__name__)
13
+
11
14
 
12
15
  # https://docs.aws.amazon.com/systems-manager/latest/userguide/session-manager-working-with-install-plugin.html
13
16
  @session_manager_app.command()
@@ -23,17 +26,17 @@ def install(
23
26
  # Check session-manager-plugin already installed
24
27
  is_installed, binary_path, version = session_manager.verify_installation()
25
28
  if is_installed:
26
- print(f"Session Manager plugin is already installed at {binary_path} (version: {version})")
29
+ logger.info("Session Manager plugin is already installed at %s (version: %s)", binary_path, version)
27
30
  return
28
31
 
29
32
  # Install session-manager-plugin
30
- print("⬇️ Installing AWS Session Manager plugin. You could be prompted for admin privileges request.")
33
+ logger.warning("Installing AWS Session Manager plugin. You could be prompted for admin privileges request.")
31
34
  session_manager.install(confirm=yes, downloader=TQDMDownloader())
32
35
 
33
36
  # Verify installation
34
37
  is_installed, binary_path, version = session_manager.verify_installation()
35
38
  if not is_installed:
36
- print("Installation failed. Session Manager plugin not found.")
39
+ logger.error("Installation failed. Session Manager plugin not found.")
37
40
  raise typer.Exit(1)
38
41
 
39
- print(f"Session Manager plugin successfully installed at {binary_path} (version: {version})")
42
+ logger.info("Session Manager plugin successfully installed at %s (version: %s)", binary_path, version)
@@ -1,15 +1,19 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import logging
3
4
  import os
4
5
  import signal
5
6
  import subprocess
6
7
  from pathlib import Path # noqa: TC003
7
8
 
8
9
  import typer
9
- from rich import print # noqa: A004
10
+
11
+ from aws_annoying.utils.ec2 import get_instance_id_by_name
10
12
 
11
13
  from ._app import session_manager_app
12
- from ._common import SessionManager, get_instance_id_by_name
14
+ from ._common import SessionManager
15
+
16
+ logger = logging.getLogger(__name__)
13
17
 
14
18
 
15
19
  # https://docs.aws.amazon.com/systems-manager/latest/userguide/session-manager-working-with-install-plugin.html
@@ -59,30 +63,30 @@ def port_forward( # noqa: PLR0913
59
63
  # Check if the PID file already exists
60
64
  if pid_file.exists():
61
65
  if not terminate_running_process:
62
- print("🚫 PID file already exists.")
66
+ logger.error("PID file already exists.")
63
67
  raise typer.Exit(1)
64
68
 
65
69
  pid_content = pid_file.read_text()
66
70
  try:
67
71
  existing_pid = int(pid_content)
68
72
  except ValueError:
69
- print(f"🚫 PID file content is invalid; expected integer, but got: {type(pid_content)}")
73
+ logger.error("PID file content is invalid; expected integer, but got: %r", type(pid_content)) # noqa: TRY400
70
74
  raise typer.Exit(1) from None
71
75
 
72
76
  try:
73
- print(f"⚠️ Terminating running process with PID {existing_pid}.")
77
+ logger.warning("Terminating running process with PID %d.", existing_pid)
74
78
  os.kill(existing_pid, signal.SIGTERM)
75
79
  pid_file.write_text("") # Clear the PID file
76
80
  except ProcessLookupError:
77
- print(f"⚠️ Tried to terminate process with PID {existing_pid} but does not exist.")
81
+ logger.warning("Tried to terminate process with PID %d but does not exist.", existing_pid)
78
82
 
79
83
  # Resolve the instance name or ID
80
84
  instance_id = get_instance_id_by_name(through)
81
85
  if instance_id:
82
- print(f"Instance ID resolved: [bold]{instance_id}[/bold]")
86
+ logger.info("Instance ID resolved: [bold]%s[/bold]", instance_id)
83
87
  target = instance_id
84
88
  else:
85
- print(f"🚫 Instance with name '{through}' not found.")
89
+ logger.info("Instance with name '%s' not found.", through)
86
90
  raise typer.Exit(1)
87
91
 
88
92
  # Initiate the session
@@ -102,8 +106,10 @@ def port_forward( # noqa: PLR0913
102
106
  else:
103
107
  stdout = subprocess.DEVNULL
104
108
 
105
- print(
106
- f"🚀 Starting port forwarding session through [bold]{through}[/bold] with reason: [italic]{reason!r}[/italic].",
109
+ logger.info(
110
+ "Starting port forwarding session through [bold]%s[/bold] with reason: [italic]%r[/italic].",
111
+ through,
112
+ reason,
107
113
  )
108
114
  proc = subprocess.Popen( # noqa: S603
109
115
  command,
@@ -112,8 +118,12 @@ def port_forward( # noqa: PLR0913
112
118
  text=True,
113
119
  close_fds=False, # FD inherited from parent process
114
120
  )
115
- print(f"✅ Session Manager Plugin started with PID {proc.pid}. Outputs will be logged to {log_file.absolute()}.")
121
+ logger.info(
122
+ "Session Manager Plugin started with PID %d. Outputs will be logged to %s.",
123
+ proc.pid,
124
+ log_file.absolute(),
125
+ )
116
126
 
117
127
  # Write the PID to the file
118
128
  pid_file.write_text(str(proc.pid))
119
- print(f"💾 PID file written to {pid_file.absolute()}.")
129
+ logger.info("PID file written to %s.", pid_file.absolute())
@@ -1,12 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import logging
3
4
  import os
4
5
 
5
6
  import typer
6
- from rich import print # noqa: A004
7
+
8
+ from aws_annoying.utils.ec2 import get_instance_id_by_name
7
9
 
8
10
  from ._app import session_manager_app
9
- from ._common import SessionManager, get_instance_id_by_name
11
+ from ._common import SessionManager
12
+
13
+ logger = logging.getLogger(__name__)
10
14
 
11
15
  # TODO(lasuillard): ECS support (#24)
12
16
  # TODO(lasuillard): Interactive instance selection
@@ -30,14 +34,18 @@ def start(
30
34
  # Resolve the instance name or ID
31
35
  instance_id = get_instance_id_by_name(target)
32
36
  if instance_id:
33
- print(f"Instance ID resolved: [bold]{instance_id}[/bold]")
37
+ logger.info("Instance ID resolved: [bold]%s[/bold]", instance_id)
34
38
  target = instance_id
35
39
  else:
36
- print(f"🚫 Instance with name '{target}' not found.")
40
+ logger.info("Instance with name '%s' not found.", target)
37
41
  raise typer.Exit(1)
38
42
 
39
43
  # Start the session, replacing the current process
40
- print(f"🚀 Starting session to target [bold]{target}[/bold] with reason: [italic]{reason!r}[/italic].")
44
+ logger.info(
45
+ "Starting session to target [bold]%s[/bold] with reason: [italic]%r[/italic].",
46
+ target,
47
+ reason,
48
+ )
41
49
  command = session_manager.build_command(
42
50
  target=target,
43
51
  document_name="SSM-SessionManagerRunShell",
@@ -1,14 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import logging
3
4
  import os
4
5
  import signal
5
6
  from pathlib import Path # noqa: TC003
6
7
 
7
8
  import typer
8
- from rich import print # noqa: A004
9
9
 
10
10
  from ._app import session_manager_app
11
11
 
12
+ logger = logging.getLogger(__name__)
13
+
12
14
 
13
15
  @session_manager_app.command()
14
16
  def stop(
@@ -24,7 +26,7 @@ def stop(
24
26
  """Stop running session for PID file."""
25
27
  # Check if PID file exists
26
28
  if not pid_file.is_file():
27
- print(f"PID file not found: {pid_file}")
29
+ logger.error("PID file not found: %s", pid_file)
28
30
  raise typer.Exit(1)
29
31
 
30
32
  # Read PID from file
@@ -32,19 +34,19 @@ def stop(
32
34
  try:
33
35
  pid = int(pid_content)
34
36
  except ValueError:
35
- print(f"🚫 PID file content is invalid; expected integer, but got: {type(pid_content)}")
37
+ logger.error("PID file content is invalid; expected integer, but got: %s", type(pid_content)) # noqa: TRY400
36
38
  raise typer.Exit(1) from None
37
39
 
38
40
  # Send SIGTERM to the process
39
41
  try:
40
- print(f"⚠️ Terminating running process with PID {pid}.")
42
+ logger.warning("Terminating running process with PID %d.", pid)
41
43
  os.kill(pid, signal.SIGTERM)
42
44
  except ProcessLookupError:
43
- print(f"Tried to terminate process with PID {pid} but does not exist.")
45
+ logger.warning("Tried to terminate process with PID %d but does not exist.", pid)
44
46
 
45
47
  # Remove the PID file
46
48
  if remove:
47
- print(f"Removed the PID file {pid_file}.")
49
+ logger.info("Removed the PID file %s.", pid_file)
48
50
  pid_file.unlink()
49
51
 
50
- print("Terminated the session successfully.")
52
+ logger.info("Terminated the session successfully.")
@@ -0,0 +1,25 @@
1
+ from .check import check_service_task_definition
2
+ from .common import ECSServiceRef
3
+ from .errors import (
4
+ DeploymentFailedError,
5
+ NoRunningDeploymentError,
6
+ ServiceTaskDefinitionAssertionError,
7
+ WaitForDeploymentError,
8
+ )
9
+ from .wait_for import (
10
+ wait_for_deployment_complete,
11
+ wait_for_deployment_start,
12
+ wait_for_service_stability,
13
+ )
14
+
15
+ __all__ = (
16
+ "DeploymentFailedError",
17
+ "ECSServiceRef",
18
+ "NoRunningDeploymentError",
19
+ "ServiceTaskDefinitionAssertionError",
20
+ "WaitForDeploymentError",
21
+ "check_service_task_definition",
22
+ "wait_for_deployment_complete",
23
+ "wait_for_deployment_start",
24
+ "wait_for_service_stability",
25
+ )
@@ -0,0 +1,39 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import TYPE_CHECKING
5
+
6
+ import boto3
7
+
8
+ if TYPE_CHECKING:
9
+ from .common import ECSServiceRef
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def check_service_task_definition(
15
+ service_ref: ECSServiceRef,
16
+ *,
17
+ session: boto3.session.Session | None = None,
18
+ expect: str,
19
+ ) -> tuple[bool, str]:
20
+ """Check the service's current task definition matches the expected one.
21
+
22
+ Args:
23
+ service_ref: The ECS service reference containing the cluster and service names.
24
+ session: The boto3 session to use for the ECS client.
25
+ expect: The ARN of expected task definition.
26
+
27
+ Returns:
28
+ A tuple containing a boolean indicating whether the task definition matches the expected one
29
+ and the current task definition ARN.
30
+ """
31
+ session = session or boto3.session.Session()
32
+ ecs = session.client("ecs")
33
+
34
+ service_detail = ecs.describe_services(cluster=service_ref.cluster, services=[service_ref.service])["services"][0]
35
+ current_task_definition_arn = service_detail["taskDefinition"]
36
+ if current_task_definition_arn != expect:
37
+ return (False, current_task_definition_arn)
38
+
39
+ return (True, current_task_definition_arn)
@@ -0,0 +1,8 @@
1
+ from typing import NamedTuple
2
+
3
+
4
+ class ECSServiceRef(NamedTuple):
5
+ """Reference to an ECS service."""
6
+
7
+ cluster: str
8
+ service: str
@@ -0,0 +1,14 @@
1
+ class WaitForDeploymentError(Exception):
2
+ """Base class for all deployment waiter errors."""
3
+
4
+
5
+ class NoRunningDeploymentError(WaitForDeploymentError):
6
+ """No running deployment found for the service."""
7
+
8
+
9
+ class DeploymentFailedError(WaitForDeploymentError):
10
+ """Deployment failed."""
11
+
12
+
13
+ class ServiceTaskDefinitionAssertionError(WaitForDeploymentError):
14
+ """Service task definition does not match the expected one."""
@@ -0,0 +1,190 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from datetime import datetime, timezone
5
+ from time import sleep
6
+ from typing import TYPE_CHECKING
7
+
8
+ import boto3
9
+ import botocore.exceptions
10
+
11
+ from .errors import NoRunningDeploymentError
12
+
13
+ if TYPE_CHECKING:
14
+ from .common import ECSServiceRef
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def wait_for_deployment_start(
20
+ service_ref: ECSServiceRef,
21
+ *,
22
+ session: boto3.session.Session | None = None,
23
+ wait_for_start: bool,
24
+ polling_interval: int = 5,
25
+ max_attempts: int | None = None,
26
+ ) -> str:
27
+ """Wait for the ECS deployment to start.
28
+
29
+ Args:
30
+ service_ref: The ECS service reference containing the cluster and service names.
31
+ session: The boto3 session to use for the ECS client.
32
+ wait_for_start: Whether to wait for the deployment to start.
33
+ polling_interval: The interval between any polling attempts, in seconds.
34
+ max_attempts: The maximum number of attempts to wait for the deployment to start.
35
+
36
+ Raises:
37
+ NoRunningDeploymentError: If no running deployments are found and `wait_for_start` is False.
38
+
39
+ Returns:
40
+ The ARN of the latest deployment for the service.
41
+ """
42
+ session = session or boto3.session.Session()
43
+ ecs = session.client("ecs")
44
+
45
+ if wait_for_start:
46
+ logger.warning("`wait_for_start` is set, will wait for a new deployment to start.")
47
+
48
+ attempts = 0
49
+ while True: # do-while
50
+ # Do
51
+ running_deployments = ecs.list_service_deployments(
52
+ cluster=service_ref.cluster,
53
+ service=service_ref.service,
54
+ status=["PENDING", "IN_PROGRESS"],
55
+ )["serviceDeployments"]
56
+
57
+ # While
58
+ if running_deployments:
59
+ logger.debug("Found %d running deployments for service. Exiting loop.", len(running_deployments))
60
+ break
61
+
62
+ if not wait_for_start:
63
+ logger.debug("`wait_for_start` is off, no need to wait for a new deployment to start.")
64
+ break
65
+
66
+ if max_attempts and attempts >= max_attempts:
67
+ logger.debug("Max attempts exceeded while waiting for a new deployment to start.")
68
+ break
69
+
70
+ logger.debug(
71
+ "(%d-th attempt) No running deployments found for service. Start waiting for a new deployment.",
72
+ attempts + 1,
73
+ )
74
+
75
+ sleep(polling_interval)
76
+ attempts += 1
77
+
78
+ if not running_deployments:
79
+ msg = "No running deployments found for service."
80
+ raise NoRunningDeploymentError(msg)
81
+
82
+ latest_deployment = max(
83
+ running_deployments,
84
+ key=lambda dep: dep.get(
85
+ "startedAt",
86
+ datetime.min.replace(tzinfo=timezone.utc),
87
+ ),
88
+ )
89
+ if len(running_deployments) > 1:
90
+ logger.warning(
91
+ "%d running deployments found for service. Using most recently started deployment: %s",
92
+ len(running_deployments),
93
+ latest_deployment["serviceDeploymentArn"],
94
+ )
95
+
96
+ return latest_deployment["serviceDeploymentArn"]
97
+
98
+
99
+ def wait_for_deployment_complete(
100
+ deployment_arn: str,
101
+ *,
102
+ session: boto3.session.Session | None = None,
103
+ polling_interval: int = 5,
104
+ max_attempts: int | None = None,
105
+ ) -> tuple[bool, str]:
106
+ """Wait for the ECS deployment to complete.
107
+
108
+ Args:
109
+ deployment_arn: The ARN of the deployment to wait for.
110
+ session: The boto3 session to use for the ECS client.
111
+ polling_interval: The interval between any polling attempts, in seconds.
112
+ max_attempts: The maximum number of attempts to wait for the deployment to complete.
113
+
114
+ Returns:
115
+ A tuple containing a boolean indicating whether the deployment succeeded and the status of the deployment.
116
+ """
117
+ session = session or boto3.session.Session()
118
+ ecs = session.client("ecs")
119
+
120
+ attempts = 0
121
+ while (max_attempts is None) or (attempts <= max_attempts):
122
+ latest_deployment = ecs.describe_service_deployments(serviceDeploymentArns=[deployment_arn])[
123
+ "serviceDeployments"
124
+ ][0]
125
+ status = latest_deployment["status"]
126
+ if status == "SUCCESSFUL":
127
+ return (True, status)
128
+
129
+ if status in ("PENDING", "IN_PROGRESS"):
130
+ logger.debug(
131
+ "(%d-th attempt) Deployment in progress... (%s)",
132
+ attempts + 1,
133
+ status,
134
+ )
135
+ else:
136
+ break
137
+
138
+ sleep(polling_interval)
139
+ attempts += 1
140
+
141
+ return (False, status)
142
+
143
+
144
+ def wait_for_service_stability(
145
+ service_ref: ECSServiceRef,
146
+ *,
147
+ session: boto3.session.Session | None = None,
148
+ polling_interval: int = 5,
149
+ max_attempts: int | None = None,
150
+ ) -> bool:
151
+ """Wait for the ECS service to be stable.
152
+
153
+ Args:
154
+ service_ref: The ECS service reference containing the cluster and service names.
155
+ session: The boto3 session to use for the ECS client.
156
+ polling_interval: The interval between any polling attempts, in seconds.
157
+ max_attempts: The maximum number of attempts to wait for the service to be stable.
158
+
159
+ Returns:
160
+ A boolean indicating whether the service is stable.
161
+ """
162
+ session = session or boto3.session.Session()
163
+ ecs = session.client("ecs")
164
+
165
+ # TODO(lasuillard): Likely to be a problem in some cases: https://github.com/boto/botocore/issues/3314
166
+ stability_waiter = ecs.get_waiter("services_stable")
167
+
168
+ attempts = 0
169
+ while (max_attempts is None) or (attempts <= max_attempts):
170
+ logger.debug(
171
+ "(%d-th attempt) Waiting for service %s to be stable...",
172
+ attempts + 1,
173
+ service_ref.service,
174
+ )
175
+ try:
176
+ stability_waiter.wait(
177
+ cluster=service_ref.cluster,
178
+ services=[service_ref.service],
179
+ WaiterConfig={"Delay": polling_interval, "MaxAttempts": 1},
180
+ )
181
+ except botocore.exceptions.WaiterError as err:
182
+ if err.kwargs["reason"] != "Max attempts exceeded":
183
+ raise
184
+ else:
185
+ return True
186
+
187
+ sleep(polling_interval)
188
+ attempts += 1
189
+
190
+ return False
@@ -1,13 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import configparser
4
+ import logging
4
5
  from pathlib import Path # noqa: TC003
5
6
  from typing import Optional
6
7
 
7
8
  from pydantic import BaseModel, ConfigDict
8
9
 
9
- # TODO(lasuillard): Need some refactoring (configurator class)
10
- # TODO(lasuillard): Put some logging
10
+ logger = logging.getLogger(__name__)
11
11
 
12
12
 
13
13
  class MfaConfig(BaseModel):
@@ -30,9 +30,12 @@ class MfaConfig(BaseModel):
30
30
  with path.open("w") as f:
31
31
  config_ini.write(f)
32
32
 
33
+ logger.debug("Saved config to %s with section %s", path, section_key)
34
+
33
35
  @classmethod
34
36
  def from_ini_file(cls, path: Path, section_key: str) -> tuple[MfaConfig, bool]:
35
37
  """Load configuration from an AWS config file, with boolean indicating if the config already exists."""
38
+ logger.debug("Loading config from %s with section %s", path, section_key)
36
39
  config_ini = configparser.ConfigParser()
37
40
  config_ini.read(path)
38
41
  if config_ini.has_section(section_key):
@@ -52,3 +55,5 @@ def update_credentials(path: Path, profile: str, *, access_key: str, secret_key:
52
55
  credentials_ini[profile]["aws_session_token"] = session_token
53
56
  with path.open("w") as f:
54
57
  credentials_ini.write(f)
58
+
59
+ logger.debug("Updated credentials file %s with profile %s", path, profile)
@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, NamedTuple
13
13
 
14
14
  import boto3
15
15
 
16
- from aws_annoying.utils.platform import command_as_root, is_root, os_release
16
+ from aws_annoying.utils.platform import command_as_root, is_root, is_windows, os_release
17
17
 
18
18
  from .errors import PluginNotInstalledError, UnsupportedPlatformError
19
19
 
@@ -22,8 +22,6 @@ if TYPE_CHECKING:
22
22
 
23
23
  logger = logging.getLogger(__name__)
24
24
 
25
- # TODO(lasuillard): Platform checking is spread everywhere, should be moved to a single place
26
-
27
25
 
28
26
  class SessionManager:
29
27
  """AWS Session Manager plugin manager."""
@@ -65,7 +63,7 @@ class SessionManager:
65
63
  """Get the path to the session-manager-plugin binary."""
66
64
  binary_path_str = shutil.which("session-manager-plugin")
67
65
  if not binary_path_str:
68
- if platform.system() == "Windows":
66
+ if is_windows():
69
67
  # Windows: use the default installation path
70
68
  binary_path = (
71
69
  Path(os.environ["ProgramFiles"]) # noqa: SIM112
@@ -5,6 +5,8 @@ import subprocess
5
5
  from contextlib import contextmanager
6
6
  from typing import TYPE_CHECKING
7
7
 
8
+ from aws_annoying.utils.timeout import Timeout
9
+
8
10
  from .session_manager import SessionManager
9
11
 
10
12
  if TYPE_CHECKING:
@@ -15,13 +17,14 @@ logger = logging.getLogger(__name__)
15
17
 
16
18
 
17
19
  @contextmanager
18
- def port_forward(
20
+ def port_forward( # noqa: PLR0913
19
21
  *,
20
22
  through: str,
21
23
  local_port: int,
22
24
  remote_host: str,
23
25
  remote_port: int,
24
26
  reason: str | None = None,
27
+ start_timeout: int | None = None,
25
28
  ) -> Iterator[subprocess.Popen[str]]:
26
29
  """Context manager for port forwarding sessions.
27
30
 
@@ -31,6 +34,7 @@ def port_forward(
31
34
  remote_host: The remote host to connect to.
32
35
  remote_port: The remote port to connect to.
33
36
  reason: The reason for starting the session.
37
+ start_timeout: The timeout in seconds to wait for the session to start.
34
38
 
35
39
  Returns:
36
40
  The command to start the session.
@@ -61,11 +65,11 @@ def port_forward(
61
65
 
62
66
  # Wait for the session to start
63
67
  # ? Not sure this is trustworthy health check
64
- # TODO(lasuillard): Need timeout to avoid hanging forever
65
- for line in proc.stdout:
66
- if "Waiting for connections..." in line:
67
- logger.info("Session started successfully.")
68
- break
68
+ with Timeout(start_timeout):
69
+ for line in proc.stdout:
70
+ if "Waiting for connections..." in line:
71
+ logger.info("Session started successfully.")
72
+ break
69
73
 
70
74
  yield proc
71
75
  finally:
@@ -42,14 +42,7 @@ class TQDMDownloader(AbstractDownloader):
42
42
  total_size = int(response.headers.get("content-length", 0))
43
43
  with (
44
44
  to.open("wb") as f,
45
- tqdm(
46
- # Make the URL less verbose in the progress bar
47
- desc=url.replace("https://s3.amazonaws.com/session-manager-downloads/plugin", "..."),
48
- total=total_size,
49
- unit="iB",
50
- unit_scale=True,
51
- unit_divisor=1_024,
52
- ) as pbar,
45
+ tqdm(desc=url, total=total_size, unit="iB", unit_scale=True, unit_divisor=1_024) as pbar,
53
46
  ):
54
47
  for chunk in response.iter_content(chunk_size=8_192):
55
48
  size = f.write(chunk)
@@ -0,0 +1,33 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+
5
+ import boto3
6
+
7
+
8
+ def get_instance_id_by_name(name_or_id: str, *, session: boto3.session.Session | None = None) -> str | None:
9
+ """Get the EC2 instance ID by name or ID.
10
+
11
+ Be aware that this function will only return the first instance found
12
+ with the given name, no matter how many instances are found.
13
+
14
+ Args:
15
+ name_or_id: The name or ID of the EC2 instance.
16
+ session: The boto3 session to use. If not provided, a new session will be created.
17
+
18
+ Returns:
19
+ The instance ID if found, otherwise `None`.
20
+ """
21
+ if re.match(r"^m?i-[0-9a-f]+$", name_or_id):
22
+ return name_or_id
23
+
24
+ session = session or boto3.session.Session()
25
+ ec2 = session.client("ec2")
26
+
27
+ response = ec2.describe_instances(Filters=[{"Name": "tag:Name", "Values": [name_or_id]}])
28
+ reservations = response["Reservations"]
29
+ if not reservations or not reservations[0]["Instances"]:
30
+ return None
31
+
32
+ instances = reservations[0]["Instances"]
33
+ return str(instances[0]["InstanceId"])