aws-annoying 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aws_annoying/cli/app.py +81 -0
- aws_annoying/cli/ecs/__init__.py +3 -0
- aws_annoying/cli/ecs/_app.py +9 -0
- aws_annoying/cli/{ecs_task_definition_lifecycle.py → ecs/task_definition_lifecycle.py} +18 -13
- aws_annoying/cli/ecs/wait_for_deployment.py +94 -0
- aws_annoying/cli/load_variables.py +22 -22
- aws_annoying/cli/logging_handler.py +52 -0
- aws_annoying/cli/main.py +1 -1
- aws_annoying/cli/mfa/configure.py +21 -12
- aws_annoying/cli/session_manager/_common.py +1 -32
- aws_annoying/cli/session_manager/install.py +8 -5
- aws_annoying/cli/session_manager/port_forward.py +22 -12
- aws_annoying/cli/session_manager/start.py +13 -5
- aws_annoying/cli/session_manager/stop.py +9 -7
- aws_annoying/ecs/__init__.py +17 -0
- aws_annoying/ecs/common.py +8 -0
- aws_annoying/ecs/deployment_waiter.py +274 -0
- aws_annoying/ecs/errors.py +14 -0
- aws_annoying/{mfa.py → mfa_config.py} +7 -2
- aws_annoying/session_manager/session_manager.py +2 -4
- aws_annoying/session_manager/shortcuts.py +10 -6
- aws_annoying/utils/ec2.py +36 -0
- aws_annoying/utils/platform.py +11 -0
- aws_annoying/utils/timeout.py +88 -0
- aws_annoying/{variables.py → variable_loader.py} +11 -16
- {aws_annoying-0.5.0.dist-info → aws_annoying-0.6.0.dist-info}/METADATA +47 -2
- aws_annoying-0.6.0.dist-info/RECORD +41 -0
- aws_annoying-0.5.0.dist-info/RECORD +0 -31
- {aws_annoying-0.5.0.dist-info → aws_annoying-0.6.0.dist-info}/WHEEL +0 -0
- {aws_annoying-0.5.0.dist-info → aws_annoying-0.6.0.dist-info}/entry_points.txt +0 -0
- {aws_annoying-0.5.0.dist-info → aws_annoying-0.6.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,15 +1,19 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import logging
|
|
3
4
|
import os
|
|
4
5
|
import signal
|
|
5
6
|
import subprocess
|
|
6
7
|
from pathlib import Path # noqa: TC003
|
|
7
8
|
|
|
8
9
|
import typer
|
|
9
|
-
|
|
10
|
+
|
|
11
|
+
from aws_annoying.utils.ec2 import get_instance_id_by_name
|
|
10
12
|
|
|
11
13
|
from ._app import session_manager_app
|
|
12
|
-
from ._common import SessionManager
|
|
14
|
+
from ._common import SessionManager
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
13
17
|
|
|
14
18
|
|
|
15
19
|
# https://docs.aws.amazon.com/systems-manager/latest/userguide/session-manager-working-with-install-plugin.html
|
|
@@ -59,30 +63,30 @@ def port_forward( # noqa: PLR0913
|
|
|
59
63
|
# Check if the PID file already exists
|
|
60
64
|
if pid_file.exists():
|
|
61
65
|
if not terminate_running_process:
|
|
62
|
-
|
|
66
|
+
logger.error("PID file already exists.")
|
|
63
67
|
raise typer.Exit(1)
|
|
64
68
|
|
|
65
69
|
pid_content = pid_file.read_text()
|
|
66
70
|
try:
|
|
67
71
|
existing_pid = int(pid_content)
|
|
68
72
|
except ValueError:
|
|
69
|
-
|
|
73
|
+
logger.error("PID file content is invalid; expected integer, but got: %r", type(pid_content)) # noqa: TRY400
|
|
70
74
|
raise typer.Exit(1) from None
|
|
71
75
|
|
|
72
76
|
try:
|
|
73
|
-
|
|
77
|
+
logger.warning("Terminating running process with PID %d.", existing_pid)
|
|
74
78
|
os.kill(existing_pid, signal.SIGTERM)
|
|
75
79
|
pid_file.write_text("") # Clear the PID file
|
|
76
80
|
except ProcessLookupError:
|
|
77
|
-
|
|
81
|
+
logger.warning("Tried to terminate process with PID %d but does not exist.", existing_pid)
|
|
78
82
|
|
|
79
83
|
# Resolve the instance name or ID
|
|
80
84
|
instance_id = get_instance_id_by_name(through)
|
|
81
85
|
if instance_id:
|
|
82
|
-
|
|
86
|
+
logger.info("Instance ID resolved: [bold]%s[/bold]", instance_id)
|
|
83
87
|
target = instance_id
|
|
84
88
|
else:
|
|
85
|
-
|
|
89
|
+
logger.info("Instance with name '%s' not found.", through)
|
|
86
90
|
raise typer.Exit(1)
|
|
87
91
|
|
|
88
92
|
# Initiate the session
|
|
@@ -102,8 +106,10 @@ def port_forward( # noqa: PLR0913
|
|
|
102
106
|
else:
|
|
103
107
|
stdout = subprocess.DEVNULL
|
|
104
108
|
|
|
105
|
-
|
|
106
|
-
|
|
109
|
+
logger.info(
|
|
110
|
+
"Starting port forwarding session through [bold]%s[/bold] with reason: [italic]%r[/italic].",
|
|
111
|
+
through,
|
|
112
|
+
reason,
|
|
107
113
|
)
|
|
108
114
|
proc = subprocess.Popen( # noqa: S603
|
|
109
115
|
command,
|
|
@@ -112,8 +118,12 @@ def port_forward( # noqa: PLR0913
|
|
|
112
118
|
text=True,
|
|
113
119
|
close_fds=False, # FD inherited from parent process
|
|
114
120
|
)
|
|
115
|
-
|
|
121
|
+
logger.info(
|
|
122
|
+
"Session Manager Plugin started with PID %d. Outputs will be logged to %s.",
|
|
123
|
+
proc.pid,
|
|
124
|
+
log_file.absolute(),
|
|
125
|
+
)
|
|
116
126
|
|
|
117
127
|
# Write the PID to the file
|
|
118
128
|
pid_file.write_text(str(proc.pid))
|
|
119
|
-
|
|
129
|
+
logger.info("PID file written to %s.", pid_file.absolute())
|
|
@@ -1,12 +1,16 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import logging
|
|
3
4
|
import os
|
|
4
5
|
|
|
5
6
|
import typer
|
|
6
|
-
|
|
7
|
+
|
|
8
|
+
from aws_annoying.utils.ec2 import get_instance_id_by_name
|
|
7
9
|
|
|
8
10
|
from ._app import session_manager_app
|
|
9
|
-
from ._common import SessionManager
|
|
11
|
+
from ._common import SessionManager
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
10
14
|
|
|
11
15
|
# TODO(lasuillard): ECS support (#24)
|
|
12
16
|
# TODO(lasuillard): Interactive instance selection
|
|
@@ -30,14 +34,18 @@ def start(
|
|
|
30
34
|
# Resolve the instance name or ID
|
|
31
35
|
instance_id = get_instance_id_by_name(target)
|
|
32
36
|
if instance_id:
|
|
33
|
-
|
|
37
|
+
logger.info("Instance ID resolved: [bold]%s[/bold]", instance_id)
|
|
34
38
|
target = instance_id
|
|
35
39
|
else:
|
|
36
|
-
|
|
40
|
+
logger.info("Instance with name '%s' not found.", target)
|
|
37
41
|
raise typer.Exit(1)
|
|
38
42
|
|
|
39
43
|
# Start the session, replacing the current process
|
|
40
|
-
|
|
44
|
+
logger.info(
|
|
45
|
+
"Starting session to target [bold]%s[/bold] with reason: [italic]%r[/italic].",
|
|
46
|
+
target,
|
|
47
|
+
reason,
|
|
48
|
+
)
|
|
41
49
|
command = session_manager.build_command(
|
|
42
50
|
target=target,
|
|
43
51
|
document_name="SSM-SessionManagerRunShell",
|
|
@@ -1,14 +1,16 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import logging
|
|
3
4
|
import os
|
|
4
5
|
import signal
|
|
5
6
|
from pathlib import Path # noqa: TC003
|
|
6
7
|
|
|
7
8
|
import typer
|
|
8
|
-
from rich import print # noqa: A004
|
|
9
9
|
|
|
10
10
|
from ._app import session_manager_app
|
|
11
11
|
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
12
14
|
|
|
13
15
|
@session_manager_app.command()
|
|
14
16
|
def stop(
|
|
@@ -24,7 +26,7 @@ def stop(
|
|
|
24
26
|
"""Stop running session for PID file."""
|
|
25
27
|
# Check if PID file exists
|
|
26
28
|
if not pid_file.is_file():
|
|
27
|
-
|
|
29
|
+
logger.error("PID file not found: %s", pid_file)
|
|
28
30
|
raise typer.Exit(1)
|
|
29
31
|
|
|
30
32
|
# Read PID from file
|
|
@@ -32,19 +34,19 @@ def stop(
|
|
|
32
34
|
try:
|
|
33
35
|
pid = int(pid_content)
|
|
34
36
|
except ValueError:
|
|
35
|
-
|
|
37
|
+
logger.error("PID file content is invalid; expected integer, but got: %s", type(pid_content)) # noqa: TRY400
|
|
36
38
|
raise typer.Exit(1) from None
|
|
37
39
|
|
|
38
40
|
# Send SIGTERM to the process
|
|
39
41
|
try:
|
|
40
|
-
|
|
42
|
+
logger.warning("Terminating running process with PID %d.", pid)
|
|
41
43
|
os.kill(pid, signal.SIGTERM)
|
|
42
44
|
except ProcessLookupError:
|
|
43
|
-
|
|
45
|
+
logger.warning("Tried to terminate process with PID %d but does not exist.", pid)
|
|
44
46
|
|
|
45
47
|
# Remove the PID file
|
|
46
48
|
if remove:
|
|
47
|
-
|
|
49
|
+
logger.info("Removed the PID file %s.", pid_file)
|
|
48
50
|
pid_file.unlink()
|
|
49
51
|
|
|
50
|
-
|
|
52
|
+
logger.info("Terminated the session successfully.")
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from .common import ECSServiceRef
|
|
2
|
+
from .deployment_waiter import ECSDeploymentWaiter
|
|
3
|
+
from .errors import (
|
|
4
|
+
DeploymentFailedError,
|
|
5
|
+
NoRunningDeploymentError,
|
|
6
|
+
ServiceTaskDefinitionAssertionError,
|
|
7
|
+
WaitForDeploymentError,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
__all__ = (
|
|
11
|
+
"DeploymentFailedError",
|
|
12
|
+
"ECSDeploymentWaiter",
|
|
13
|
+
"ECSServiceRef",
|
|
14
|
+
"NoRunningDeploymentError",
|
|
15
|
+
"ServiceTaskDefinitionAssertionError",
|
|
16
|
+
"WaitForDeploymentError",
|
|
17
|
+
)
|
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from operator import itemgetter
|
|
5
|
+
from time import sleep
|
|
6
|
+
from typing import TYPE_CHECKING, Optional
|
|
7
|
+
|
|
8
|
+
import boto3
|
|
9
|
+
import botocore.exceptions
|
|
10
|
+
from pydantic import PositiveInt, validate_call
|
|
11
|
+
|
|
12
|
+
from .errors import DeploymentFailedError, NoRunningDeploymentError, ServiceTaskDefinitionAssertionError
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from .common import ECSServiceRef
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ECSDeploymentWaiter:
    """ECS service deployment waiter.

    Tracks a single ECS service deployment: finds the most recently started
    running (``PENDING``/``IN_PROGRESS``) deployment, polls it until it leaves
    those states, optionally waits for the service to become stable, and
    optionally asserts the service ended up on an expected task definition.
    """

    def __init__(self, service_ref: ECSServiceRef, *, session: boto3.session.Session | None = None) -> None:
        """Initialize instance.

        Args:
            service_ref: Reference to the ECS service.
            session: Boto3 session to use for AWS operations. When omitted, a new
                default session is created.

        """
        self.service_ref = service_ref
        self.session = session or boto3.session.Session()

    @validate_call
    def wait(
        self,
        *,
        wait_for_start: bool,
        polling_interval: PositiveInt = 5,
        wait_for_stability: bool,
        expected_task_definition: Optional[str] = None,
    ) -> None:
        """Wait for the ECS deployment to complete.

        Args:
            wait_for_start: Whether to wait for the deployment to start.
            polling_interval: The interval between any polling attempts, in seconds.
            wait_for_stability: Whether to wait for the service to be stable after the deployment.
            expected_task_definition: The service's task definition expected after deployment.

        Raises:
            NoRunningDeploymentError: If no running deployment is found for the service.
            DeploymentFailedError: If the deployment finishes with a status other than ``SUCCESSFUL``.
            ServiceTaskDefinitionAssertionError: If the service's task definition after the
                deployment differs from `expected_task_definition`.
            botocore.exceptions.WaiterError: If the stability waiter reports a real failure
                (anything other than "not stable yet").

        """
        # Find current deployment for the service
        logger.info(
            "Looking up running deployment for service %s",
            self.service_ref.service,
        )
        latest_deployment_arn = self.get_latest_deployment_arn(
            wait_for_start=wait_for_start,
            polling_interval=polling_interval,
        )

        # Polling for the deployment to finish (successfully or unsuccessfully)
        logger.info(
            "Start waiting for deployment %s to finish.",
            latest_deployment_arn,
        )
        ok, status = self.wait_for_deployment_complete(latest_deployment_arn, polling_interval=polling_interval)
        if not ok:
            msg = f"Deployment failed with status: {status}"
            raise DeploymentFailedError(msg)

        logger.info(
            "Deployment succeeded with status %s",
            status,
        )

        # Wait for the service to be stable. With no `max_attempts` this either
        # returns True or raises, so the boolean result carries no extra signal here.
        if wait_for_stability:
            logger.debug(
                "Start waiting for service %s to be stable.",
                self.service_ref.service,
            )
            self.wait_for_service_stability(polling_interval=polling_interval)

        # Check if the service task definition matches the expected one
        if expected_task_definition:
            logger.info(
                "Checking if the service task definition is the expected one: %s",
                expected_task_definition,
            )
            ok, actual = self.check_service_task_definition_is(expect=expected_task_definition)
            if not ok:
                msg = f"The service task definition is not the expected one; got: {actual!r}"
                raise ServiceTaskDefinitionAssertionError(msg)

            logger.info("The service task definition matches the expected one.")

    @validate_call
    def get_latest_deployment_arn(
        self,
        *,
        wait_for_start: bool,
        polling_interval: PositiveInt,
        max_attempts: Optional[PositiveInt] = None,
    ) -> str:
        """Get the most recently started deployment ARN for the service.

        Args:
            wait_for_start: Whether to wait for the deployment to start.
            polling_interval: The interval between any polling attempts, in seconds.
            max_attempts: Number of additional polls allowed after the first one while
                waiting for a deployment to appear. ``None`` waits indefinitely.

        Raises:
            NoRunningDeploymentError: If no running deployment is found. This happens
                immediately when `wait_for_start` is False, or once `max_attempts` is
                exhausted otherwise.

        Returns:
            The ARN of the latest deployment for the service.

        """
        ecs = self.session.client("ecs")
        if wait_for_start:
            logger.warning("`wait_for_start` is set, will wait for a new deployment to start.")

        attempts = 0
        while True:  # do-while: always poll at least once
            running_deployments = ecs.list_service_deployments(
                cluster=self.service_ref.cluster,
                service=self.service_ref.service,
                status=["PENDING", "IN_PROGRESS"],
            )["serviceDeployments"]

            if running_deployments:
                logger.debug("Found %d running deployments for service. Exiting loop.", len(running_deployments))
                break

            if not wait_for_start:
                logger.debug("`wait_for_start` is off, no need to wait for a new deployment to start.")
                break

            if max_attempts and attempts >= max_attempts:
                logger.debug("Max attempts exceeded while waiting for a new deployment to start.")
                break

            logger.debug(
                "(%d-th attempt) No running deployments found for service. Start waiting for a new deployment.",
                attempts + 1,
            )

            sleep(polling_interval)
            attempts += 1

        if not running_deployments:
            msg = "No running deployments found for service."
            raise NoRunningDeploymentError(msg)

        # `startedAt` is optional in the ListServiceDeployments response and may be
        # absent for deployments that have not started yet (e.g. PENDING); fall back
        # to `createdAt` instead of failing with a KeyError.
        latest_deployment = sorted(
            running_deployments,
            key=lambda deployment: deployment.get("startedAt") or deployment["createdAt"],
        )[-1]
        if len(running_deployments) > 1:
            logger.warning(
                "%d running deployments found for service. Using most recently started deployment: %s",
                len(running_deployments),
                latest_deployment["serviceDeploymentArn"],
            )

        return latest_deployment["serviceDeploymentArn"]

    @validate_call
    def wait_for_deployment_complete(
        self,
        deployment_arn: str,
        *,
        polling_interval: PositiveInt,
        max_attempts: Optional[PositiveInt] = None,
    ) -> tuple[bool, str]:
        """Wait for the ECS deployment to complete.

        Args:
            deployment_arn: The ARN of the deployment to wait for.
            polling_interval: The interval between any polling attempts, in seconds.
            max_attempts: Number of additional polls allowed after the first one.
                ``None`` waits until the deployment leaves the running states.

        Returns:
            A tuple of a success flag and the last observed deployment status.
            The flag is ``True`` only for ``SUCCESSFUL``. When `max_attempts` runs
            out, the status is the still-running one (``PENDING``/``IN_PROGRESS``).

        """
        ecs = self.session.client("ecs")

        attempts = 0
        while True:
            latest_deployment = ecs.describe_service_deployments(serviceDeploymentArns=[deployment_arn])[
                "serviceDeployments"
            ][0]
            status = latest_deployment["status"]
            if status == "SUCCESSFUL":
                return (True, status)

            if status not in ("PENDING", "IN_PROGRESS"):
                # Any other status is terminal and not a success.
                return (False, status)

            logger.debug(
                "(%d-th attempt) Deployment in progress... (%s)",
                attempts + 1,
                status,
            )

            # Give up without a pointless final sleep once the attempt budget is spent.
            if max_attempts is not None and attempts >= max_attempts:
                return (False, status)

            sleep(polling_interval)
            attempts += 1

    @validate_call
    def wait_for_service_stability(
        self,
        *,
        polling_interval: PositiveInt,
        max_attempts: Optional[PositiveInt] = None,
    ) -> bool:
        """Wait for the ECS service to be stable.

        Args:
            polling_interval: The interval between any polling attempts, in seconds.
            max_attempts: Number of additional checks allowed after the first one.
                ``None`` waits indefinitely.

        Raises:
            botocore.exceptions.WaiterError: If the waiter fails for any reason other
                than "not stable yet" (its attempts being exhausted).

        Returns:
            A boolean indicating whether the service is stable.

        """
        ecs = self.session.client("ecs")

        # TODO(lasuillard): Likely to be a problem in some cases: https://github.com/boto/botocore/issues/3314
        stability_waiter = ecs.get_waiter("services_stable")

        attempts = 0
        while True:
            logger.debug(
                "(%d-th attempt) Waiting for service %s to be stable...",
                attempts + 1,
                self.service_ref.service,
            )
            try:
                # One-shot waiter: polling and back-off are driven by this loop instead.
                stability_waiter.wait(
                    cluster=self.service_ref.cluster,
                    services=[self.service_ref.service],
                    WaiterConfig={"Delay": polling_interval, "MaxAttempts": 1},
                )
            except botocore.exceptions.WaiterError as err:
                # Exhausting the single attempt means "not stable yet". botocore may append
                # "Previously accepted state: ..." to this reason, so match on the prefix.
                # Any other reason is a real failure and is propagated.
                reason = str(err.kwargs.get("reason", ""))
                if not reason.startswith("Max attempts exceeded"):
                    raise
            else:
                return True

            # Give up without a pointless final sleep once the attempt budget is spent.
            if max_attempts is not None and attempts >= max_attempts:
                return False

            sleep(polling_interval)
            attempts += 1

    @validate_call
    def check_service_task_definition_is(self, expect: str) -> tuple[bool, str]:
        """Check the service's current task definition matches the expected one.

        Args:
            expect: The ARN of expected task definition. It is compared verbatim
                against the ARN reported by ``DescribeServices``.

        Returns:
            A tuple containing a boolean indicating whether the task definition matches the expected one
            and the current task definition ARN.

        """
        ecs = self.session.client("ecs")

        services = ecs.describe_services(cluster=self.service_ref.cluster, services=[self.service_ref.service])[
            "services"
        ]
        current_task_definition_arn = services[0]["taskDefinition"]
        return (current_task_definition_arn == expect, current_task_definition_arn)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
class WaitForDeploymentError(Exception):
    """Common ancestor of every error raised while waiting on an ECS deployment."""


class NoRunningDeploymentError(WaitForDeploymentError):
    """The service has no deployment in a running state to wait on."""


class DeploymentFailedError(WaitForDeploymentError):
    """The awaited deployment finished without succeeding."""


class ServiceTaskDefinitionAssertionError(WaitForDeploymentError):
    """The service's task definition differs from the one that was expected."""
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import configparser
|
|
4
|
+
import logging
|
|
4
5
|
from pathlib import Path # noqa: TC003
|
|
5
6
|
from typing import Optional
|
|
6
7
|
|
|
7
8
|
from pydantic import BaseModel, ConfigDict
|
|
8
9
|
|
|
9
|
-
|
|
10
|
-
# TODO(lasuillard): Put some logging
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class MfaConfig(BaseModel):
|
|
@@ -30,9 +30,12 @@ class MfaConfig(BaseModel):
|
|
|
30
30
|
with path.open("w") as f:
|
|
31
31
|
config_ini.write(f)
|
|
32
32
|
|
|
33
|
+
logger.debug("Saved config to %s with section %s", path, section_key)
|
|
34
|
+
|
|
33
35
|
@classmethod
|
|
34
36
|
def from_ini_file(cls, path: Path, section_key: str) -> tuple[MfaConfig, bool]:
|
|
35
37
|
"""Load configuration from an AWS config file, with boolean indicating if the config already exists."""
|
|
38
|
+
logger.debug("Loading config from %s with section %s", path, section_key)
|
|
36
39
|
config_ini = configparser.ConfigParser()
|
|
37
40
|
config_ini.read(path)
|
|
38
41
|
if config_ini.has_section(section_key):
|
|
@@ -52,3 +55,5 @@ def update_credentials(path: Path, profile: str, *, access_key: str, secret_key:
|
|
|
52
55
|
credentials_ini[profile]["aws_session_token"] = session_token
|
|
53
56
|
with path.open("w") as f:
|
|
54
57
|
credentials_ini.write(f)
|
|
58
|
+
|
|
59
|
+
logger.debug("Updated credentials file %s with profile %s", path, profile)
|
|
@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, NamedTuple
|
|
|
13
13
|
|
|
14
14
|
import boto3
|
|
15
15
|
|
|
16
|
-
from aws_annoying.utils.platform import command_as_root, is_root, os_release
|
|
16
|
+
from aws_annoying.utils.platform import command_as_root, is_root, is_windows, os_release
|
|
17
17
|
|
|
18
18
|
from .errors import PluginNotInstalledError, UnsupportedPlatformError
|
|
19
19
|
|
|
@@ -22,8 +22,6 @@ if TYPE_CHECKING:
|
|
|
22
22
|
|
|
23
23
|
logger = logging.getLogger(__name__)
|
|
24
24
|
|
|
25
|
-
# TODO(lasuillard): Platform checking is spread everywhere, should be moved to a single place
|
|
26
|
-
|
|
27
25
|
|
|
28
26
|
class SessionManager:
|
|
29
27
|
"""AWS Session Manager plugin manager."""
|
|
@@ -65,7 +63,7 @@ class SessionManager:
|
|
|
65
63
|
"""Get the path to the session-manager-plugin binary."""
|
|
66
64
|
binary_path_str = shutil.which("session-manager-plugin")
|
|
67
65
|
if not binary_path_str:
|
|
68
|
-
if
|
|
66
|
+
if is_windows():
|
|
69
67
|
# Windows: use the default installation path
|
|
70
68
|
binary_path = (
|
|
71
69
|
Path(os.environ["ProgramFiles"]) # noqa: SIM112
|
|
@@ -5,6 +5,8 @@ import subprocess
|
|
|
5
5
|
from contextlib import contextmanager
|
|
6
6
|
from typing import TYPE_CHECKING
|
|
7
7
|
|
|
8
|
+
from aws_annoying.utils.timeout import Timeout
|
|
9
|
+
|
|
8
10
|
from .session_manager import SessionManager
|
|
9
11
|
|
|
10
12
|
if TYPE_CHECKING:
|
|
@@ -15,13 +17,14 @@ logger = logging.getLogger(__name__)
|
|
|
15
17
|
|
|
16
18
|
|
|
17
19
|
@contextmanager
|
|
18
|
-
def port_forward(
|
|
20
|
+
def port_forward( # noqa: PLR0913
|
|
19
21
|
*,
|
|
20
22
|
through: str,
|
|
21
23
|
local_port: int,
|
|
22
24
|
remote_host: str,
|
|
23
25
|
remote_port: int,
|
|
24
26
|
reason: str | None = None,
|
|
27
|
+
start_timeout: int | None = None,
|
|
25
28
|
) -> Iterator[subprocess.Popen[str]]:
|
|
26
29
|
"""Context manager for port forwarding sessions.
|
|
27
30
|
|
|
@@ -31,6 +34,7 @@ def port_forward(
|
|
|
31
34
|
remote_host: The remote host to connect to.
|
|
32
35
|
remote_port: The remote port to connect to.
|
|
33
36
|
reason: The reason for starting the session.
|
|
37
|
+
start_timeout: The timeout in seconds to wait for the session to start.
|
|
34
38
|
|
|
35
39
|
Returns:
|
|
36
40
|
The command to start the session.
|
|
@@ -61,11 +65,11 @@ def port_forward(
|
|
|
61
65
|
|
|
62
66
|
# Wait for the session to start
|
|
63
67
|
# ? Not sure this is trustworthy health check
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
68
|
+
with Timeout(start_timeout):
|
|
69
|
+
for line in proc.stdout:
|
|
70
|
+
if "Waiting for connections..." in line:
|
|
71
|
+
logger.info("Session started successfully.")
|
|
72
|
+
break
|
|
69
73
|
|
|
70
74
|
yield proc
|
|
71
75
|
finally:
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
|
|
5
|
+
import boto3
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_instance_id_by_name(name_or_id: str, *, session: boto3.session.Session | None = None) -> str | None:
    """Get the EC2 instance ID by name or ID.

    Values that already look like an instance ID (EC2 ``i-...`` or SSM managed
    instance ``mi-...``) are returned unchanged without calling AWS. Otherwise,
    instances are looked up by their ``Name`` tag. Terminated and shutting-down
    instances are ignored, because they can stay visible for a while after
    termination.

    Be aware that this function will only return the first instance found
    with the given name, no matter how many instances are found.

    Args:
        name_or_id: The name or ID of the EC2 instance.
        session: The boto3 session to use. If not provided, a new session will be created.

    Returns:
        The instance ID if found, otherwise `None`.

    """
    # `fullmatch` (unlike `^...$` with `re.match`) rejects a trailing newline.
    if re.fullmatch(r"m?i-[0-9a-f]+", name_or_id):
        return name_or_id

    session = session or boto3.session.Session()
    ec2 = session.client("ec2")

    response = ec2.describe_instances(
        Filters=[
            {"Name": "tag:Name", "Values": [name_or_id]},
            {"Name": "instance-state-name", "Values": ["pending", "running", "stopping", "stopped"]},
        ],
    )

    # Matches may be spread across several reservations; take the first instance found in any of them.
    return next(
        (
            str(instance["InstanceId"])
            for reservation in response["Reservations"]
            for instance in reservation["Instances"]
        ),
        None,
    )
|
aws_annoying/utils/platform.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
|
+
import platform
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
|
|
6
7
|
|
|
@@ -25,3 +26,13 @@ def os_release() -> dict[str, str]:
|
|
|
25
26
|
key.strip('"'): value.strip('"')
|
|
26
27
|
for key, value in (line.split("=", 1) for line in content.splitlines() if "=" in line)
|
|
27
28
|
}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def is_macos() -> bool:
    """Return whether the interpreter runs on macOS (``platform.system()`` reports ``"Darwin"``)."""
    system_name = platform.system()
    return system_name == "Darwin"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def is_windows() -> bool:
    """Return whether the interpreter runs on Windows (``platform.system()`` reports ``"Windows"``)."""
    system_name = platform.system()
    return system_name == "Windows"
|