dbt-platform-helper 12.2.3__py3-none-any.whl → 12.3.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dbt-platform-helper might be problematic. See the registry's advisory page for this release for more details (the original link was not preserved in this text).

@@ -5,11 +5,7 @@ from botocore.exceptions import ClientError
5
5
 
6
6
  from dbt_platform_helper.constants import CONDUIT_DOCKER_IMAGE_LOCATION
7
7
  from dbt_platform_helper.exceptions import CreateTaskTimeoutError
8
- from dbt_platform_helper.providers.ecs import get_ecs_task_arns
9
- from dbt_platform_helper.providers.secrets import get_connection_secret_arn
10
- from dbt_platform_helper.providers.secrets import (
11
- get_postgres_connection_data_updated_with_master_secret,
12
- )
8
+ from dbt_platform_helper.providers.secrets import Secrets
13
9
  from dbt_platform_helper.utils.application import Application
14
10
  from dbt_platform_helper.utils.messages import abort_with_error
15
11
 
@@ -59,7 +55,7 @@ def create_addon_client_task(
59
55
  # We cannot check for botocore.errorfactory.NoSuchEntityException as botocore generates that class on the fly as part of errorfactory.
60
56
  # factory. Checking the error code is the recommended way of handling these exceptions.
61
57
  if ex.response.get("Error", {}).get("Code", None) != "NoSuchEntity":
62
- # TODO Raise an exception to be caught at the command layer
58
+ # TODO When we are refactoring this, raise an exception to be caught at the command layer
63
59
  abort_with_error(
64
60
  f"cannot obtain Role {role_name}: {ex.response.get('Error', {}).get('Message', '')}"
65
61
  )
@@ -69,7 +65,7 @@ def create_addon_client_task(
69
65
  f"--task-group-name {task_name} "
70
66
  f"{execution_role}"
71
67
  f"--image {CONDUIT_DOCKER_IMAGE_LOCATION}:{addon_type} "
72
- f"--secrets CONNECTION_SECRET={get_connection_secret_arn(ssm_client,secrets_manager_client, secret_name)} "
68
+ f"--secrets CONNECTION_SECRET={_get_secrets_provider(application, env).get_connection_secret_arn(secret_name)} "
73
69
  "--platform-os linux "
74
70
  "--platform-arch arm64",
75
71
  shell=True,
@@ -95,8 +91,8 @@ def create_postgres_admin_task(
95
91
  "Parameter"
96
92
  ]["Value"]
97
93
  connection_string = json.dumps(
98
- get_postgres_connection_data_updated_with_master_secret(
99
- ssm_client, secrets_manager_client, read_only_secret_name, master_secret_arn
94
+ _get_secrets_provider(app, env).get_postgres_connection_data_updated_with_master_secret(
95
+ read_only_secret_name, master_secret_arn
100
96
  )
101
97
  )
102
98
 
@@ -111,6 +107,19 @@ def create_postgres_admin_task(
111
107
  )
112
108
 
113
109
 
110
+ def _temp_until_refactor_get_ecs_task_arns(ecs_client, cluster_arn: str, task_name: str):
111
+ tasks = ecs_client.list_tasks(
112
+ cluster=cluster_arn,
113
+ desiredStatus="RUNNING",
114
+ family=f"copilot-{task_name}",
115
+ )
116
+
117
+ if not tasks["taskArns"]:
118
+ return []
119
+
120
+ return tasks["taskArns"]
121
+
122
+
114
123
  def connect_to_addon_client_task(
115
124
  ecs_client,
116
125
  subprocess,
@@ -118,13 +127,14 @@ def connect_to_addon_client_task(
118
127
  env,
119
128
  cluster_arn,
120
129
  task_name,
121
- addon_client_is_running_fn=get_ecs_task_arns,
130
+ get_ecs_task_arns=_temp_until_refactor_get_ecs_task_arns,
122
131
  ):
123
132
  running = False
124
133
  tries = 0
125
134
  while tries < 15 and not running:
126
135
  tries += 1
127
- if addon_client_is_running_fn(ecs_client, cluster_arn, task_name):
136
+ # Todo: Use from ECS provider when we refactor this
137
+ if get_ecs_task_arns(ecs_client, cluster_arn, task_name):
128
138
  subprocess.call(
129
139
  "copilot task exec "
130
140
  f"--app {application_name} --env {env} "
@@ -137,8 +147,18 @@ def connect_to_addon_client_task(
137
147
  time.sleep(1)
138
148
 
139
149
  if not running:
140
- raise CreateTaskTimeoutError
150
+ raise CreateTaskTimeoutError(task_name, application_name, env)
141
151
 
142
152
 
143
153
  def _normalise_secret_name(addon_name: str) -> str:
144
154
  return addon_name.replace("-", "_").upper()
155
+
156
+
157
+ def _get_secrets_provider(application: Application, env: str) -> Secrets:
158
+ # Todo: We instantiate the secrets provider here to avoid rabbit holing, but something better probably possible when we are refactoring this area
159
+ return Secrets(
160
+ application.environments[env].session.client("ssm"),
161
+ application.environments[env].session.client("secretsmanager"),
162
+ application.name,
163
+ env,
164
+ )
@@ -7,73 +7,81 @@ from dbt_platform_helper.exceptions import ECSAgentNotRunning
7
7
  from dbt_platform_helper.exceptions import NoClusterError
8
8
 
9
9
 
10
- # Todo: Refactor to a class, review, then perhaps do the others
11
- def get_cluster_arn(ecs_client, application_name: str, env: str) -> str:
12
- for cluster_arn in ecs_client.list_clusters()["clusterArns"]:
13
- tags_response = ecs_client.list_tags_for_resource(resourceArn=cluster_arn)
14
- tags = tags_response["tags"]
15
-
16
- app_key_found = False
17
- env_key_found = False
18
- cluster_key_found = False
19
-
20
- for tag in tags:
21
- if tag["key"] == "copilot-application" and tag["value"] == application_name:
22
- app_key_found = True
23
- if tag["key"] == "copilot-environment" and tag["value"] == env:
24
- env_key_found = True
25
- if tag["key"] == "aws:cloudformation:logical-id" and tag["value"] == "Cluster":
26
- cluster_key_found = True
27
-
28
- if app_key_found and env_key_found and cluster_key_found:
29
- return cluster_arn
30
-
31
- raise NoClusterError
32
-
33
-
34
- def get_or_create_task_name(
35
- ssm_client, application_name: str, env: str, addon_name: str, parameter_name: str
36
- ) -> str:
37
- try:
38
- return ssm_client.get_parameter(Name=parameter_name)["Parameter"]["Value"]
39
- except ssm_client.exceptions.ParameterNotFound:
40
- random_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=12))
41
- return f"conduit-{application_name}-{env}-{addon_name}-{random_id}"
42
-
43
-
44
- def get_ecs_task_arns(ecs_client, cluster_arn: str, task_name: str):
45
-
46
- tasks = ecs_client.list_tasks(
47
- cluster=cluster_arn,
48
- desiredStatus="RUNNING",
49
- family=f"copilot-{task_name}",
50
- )
51
-
52
- if not tasks["taskArns"]:
53
- return []
54
-
55
- return tasks["taskArns"]
56
-
57
-
58
- def ecs_exec_is_available(ecs_client, cluster_arn: str, task_arns: List[str]):
59
-
60
- current_attemps = 0
61
- execute_command_agent_status = ""
62
-
63
- while execute_command_agent_status != "RUNNING" and current_attemps < 25:
64
-
65
- current_attemps += 1
66
-
67
- task_details = ecs_client.describe_tasks(cluster=cluster_arn, tasks=task_arns)
68
-
69
- managed_agents = task_details["tasks"][0]["containers"][0]["managedAgents"]
70
- execute_command_agent_status = [
71
- agent["lastStatus"]
72
- for agent in managed_agents
73
- if agent["name"] == "ExecuteCommandAgent"
74
- ][0]
75
-
76
- time.sleep(1)
77
-
78
- if execute_command_agent_status != "RUNNING":
79
- raise ECSAgentNotRunning
10
+ class ECS:
11
+ def __init__(self, ecs_client, ssm_client, application_name: str, env: str):
12
+ self.ecs_client = ecs_client
13
+ self.ssm_client = ssm_client
14
+ self.application_name = application_name
15
+ self.env = env
16
+
17
+ def get_cluster_arn(self) -> str:
18
+ """Returns the ARN of the ECS cluster for the given application and
19
+ environment."""
20
+ for cluster_arn in self.ecs_client.list_clusters()["clusterArns"]:
21
+ tags_response = self.ecs_client.list_tags_for_resource(resourceArn=cluster_arn)
22
+ tags = tags_response["tags"]
23
+
24
+ app_key_found = False
25
+ env_key_found = False
26
+ cluster_key_found = False
27
+
28
+ for tag in tags:
29
+ if tag["key"] == "copilot-application" and tag["value"] == self.application_name:
30
+ app_key_found = True
31
+ if tag["key"] == "copilot-environment" and tag["value"] == self.env:
32
+ env_key_found = True
33
+ if tag["key"] == "aws:cloudformation:logical-id" and tag["value"] == "Cluster":
34
+ cluster_key_found = True
35
+
36
+ if app_key_found and env_key_found and cluster_key_found:
37
+ return cluster_arn
38
+
39
+ raise NoClusterError(self.application_name, self.env)
40
+
41
+ def get_or_create_task_name(self, addon_name: str, parameter_name: str) -> str:
42
+ """Fetches the task name from SSM or creates a new one if not found."""
43
+ try:
44
+ return self.ssm_client.get_parameter(Name=parameter_name)["Parameter"]["Value"]
45
+ except self.ssm_client.exceptions.ParameterNotFound:
46
+ random_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=12))
47
+ return f"conduit-{self.application_name}-{self.env}-{addon_name}-{random_id}"
48
+
49
+ def get_ecs_task_arns(self, cluster_arn: str, task_name: str):
50
+ """Gets the ECS task ARNs for a given task name and cluster ARN."""
51
+ tasks = self.ecs_client.list_tasks(
52
+ cluster=cluster_arn,
53
+ desiredStatus="RUNNING",
54
+ family=f"copilot-{task_name}",
55
+ )
56
+
57
+ if not tasks["taskArns"]:
58
+ return []
59
+
60
+ return tasks["taskArns"]
61
+
62
+ def ecs_exec_is_available(self, cluster_arn: str, task_arns: List[str]):
63
+ """
64
+ Checks if the ExecuteCommandAgent is running on the specified ECS task.
65
+
66
+ Waits for up to 25 attempts, then raises ECSAgentNotRunning if still not
67
+ running.
68
+ """
69
+ current_attempts = 0
70
+ execute_command_agent_status = ""
71
+
72
+ while execute_command_agent_status != "RUNNING" and current_attempts < 25:
73
+ current_attempts += 1
74
+
75
+ task_details = self.ecs_client.describe_tasks(cluster=cluster_arn, tasks=task_arns)
76
+
77
+ managed_agents = task_details["tasks"][0]["containers"][0]["managedAgents"]
78
+ execute_command_agent_status = [
79
+ agent["lastStatus"]
80
+ for agent in managed_agents
81
+ if agent["name"] == "ExecuteCommandAgent"
82
+ ][0]
83
+ if execute_command_agent_status != "RUNNING":
84
+ time.sleep(1)
85
+
86
+ if execute_command_agent_status != "RUNNING":
87
+ raise ECSAgentNotRunning
@@ -9,77 +9,77 @@ from dbt_platform_helper.exceptions import ParameterNotFoundError
9
9
  from dbt_platform_helper.exceptions import SecretNotFoundError
10
10
 
11
11
 
12
- def get_postgres_connection_data_updated_with_master_secret(
13
- ssm_client, secrets_manager_client, parameter_name, secret_arn
14
- ):
15
- response = ssm_client.get_parameter(Name=parameter_name, WithDecryption=True)
16
- parameter_value = response["Parameter"]["Value"]
17
-
18
- parameter_data = json.loads(parameter_value)
19
-
20
- secret_response = secrets_manager_client.get_secret_value(SecretId=secret_arn)
21
- secret_value = json.loads(secret_response["SecretString"])
22
-
23
- parameter_data["username"] = urllib.parse.quote(secret_value["username"])
24
- parameter_data["password"] = urllib.parse.quote(secret_value["password"])
25
-
26
- return parameter_data
27
-
28
-
29
- def get_connection_secret_arn(ssm_client, secrets_manager_client, secret_name: str) -> str:
30
-
31
- try:
32
- return ssm_client.get_parameter(Name=secret_name, WithDecryption=False)["Parameter"]["ARN"]
33
- except ssm_client.exceptions.ParameterNotFound:
34
- pass
35
-
36
- try:
37
- return secrets_manager_client.describe_secret(SecretId=secret_name)["ARN"]
38
- except secrets_manager_client.exceptions.ResourceNotFoundException:
39
- pass
40
-
41
- raise SecretNotFoundError(secret_name)
42
-
43
-
44
- def get_addon_type(ssm_client, application_name: str, env: str, addon_name: str) -> str:
45
- addon_type = None
46
- try:
47
- addon_config = json.loads(
48
- ssm_client.get_parameter(
49
- Name=f"/copilot/applications/{application_name}/environments/{env}/addons"
50
- )["Parameter"]["Value"]
51
- )
52
- except ssm_client.exceptions.ParameterNotFound:
53
- raise ParameterNotFoundError
54
-
55
- if addon_name not in addon_config.keys():
56
- raise AddonNotFoundError
57
-
58
- for name, config in addon_config.items():
59
- if name == addon_name:
60
- if not config.get("type"):
61
- raise AddonTypeMissingFromConfigError()
62
- addon_type = config["type"]
63
-
64
- if not addon_type or addon_type not in CONDUIT_ADDON_TYPES:
65
- raise InvalidAddonTypeError(addon_type)
66
-
67
- if "postgres" in addon_type:
68
- addon_type = "postgres"
69
-
70
- return addon_type
71
-
72
-
73
- def get_parameter_name(
74
- application_name: str, env: str, addon_type: str, addon_name: str, access: str
75
- ) -> str:
76
- if addon_type == "postgres":
77
- return f"/copilot/{application_name}/{env}/conduits/{_normalise_secret_name(addon_name)}_{access.upper()}"
78
- elif addon_type == "redis" or addon_type == "opensearch":
79
- return f"/copilot/{application_name}/{env}/conduits/{_normalise_secret_name(addon_name)}_ENDPOINT"
80
- else:
81
- return f"/copilot/{application_name}/{env}/conduits/{_normalise_secret_name(addon_name)}"
82
-
83
-
84
- def _normalise_secret_name(addon_name: str) -> str:
85
- return addon_name.replace("-", "_").upper()
12
+ class Secrets:
13
+ def __init__(self, ssm_client, secrets_manager_client, application_name, env):
14
+ self.ssm_client = ssm_client
15
+ self.secrets_manager_client = secrets_manager_client
16
+ self.application_name = application_name
17
+ self.env = env
18
+
19
+ def get_postgres_connection_data_updated_with_master_secret(self, parameter_name, secret_arn):
20
+ response = self.ssm_client.get_parameter(Name=parameter_name, WithDecryption=True)
21
+ parameter_value = response["Parameter"]["Value"]
22
+
23
+ parameter_data = json.loads(parameter_value)
24
+
25
+ secret_response = self.secrets_manager_client.get_secret_value(SecretId=secret_arn)
26
+ secret_value = json.loads(secret_response["SecretString"])
27
+
28
+ parameter_data["username"] = urllib.parse.quote(secret_value["username"])
29
+ parameter_data["password"] = urllib.parse.quote(secret_value["password"])
30
+
31
+ return parameter_data
32
+
33
+ def get_connection_secret_arn(self, secret_name: str) -> str:
34
+ try:
35
+ return self.ssm_client.get_parameter(Name=secret_name, WithDecryption=False)[
36
+ "Parameter"
37
+ ]["ARN"]
38
+ except self.ssm_client.exceptions.ParameterNotFound:
39
+ pass
40
+
41
+ try:
42
+ return self.secrets_manager_client.describe_secret(SecretId=secret_name)["ARN"]
43
+ except self.secrets_manager_client.exceptions.ResourceNotFoundException:
44
+ pass
45
+
46
+ raise SecretNotFoundError(secret_name)
47
+
48
+ def get_addon_type(self, addon_name: str) -> str:
49
+ addon_type = None
50
+ try:
51
+ addon_config = json.loads(
52
+ self.ssm_client.get_parameter(
53
+ Name=f"/copilot/applications/{self.application_name}/environments/{self.env}/addons"
54
+ )["Parameter"]["Value"]
55
+ )
56
+ except self.ssm_client.exceptions.ParameterNotFound:
57
+ raise ParameterNotFoundError(self.application_name, self.env)
58
+
59
+ if addon_name not in addon_config.keys():
60
+ raise AddonNotFoundError(addon_name)
61
+
62
+ for name, config in addon_config.items():
63
+ if name == addon_name:
64
+ if not config.get("type"):
65
+ raise AddonTypeMissingFromConfigError(addon_name)
66
+ addon_type = config["type"]
67
+
68
+ if not addon_type or addon_type not in CONDUIT_ADDON_TYPES:
69
+ raise InvalidAddonTypeError(addon_type)
70
+
71
+ if "postgres" in addon_type:
72
+ addon_type = "postgres"
73
+
74
+ return addon_type
75
+
76
+ def get_parameter_name(self, addon_type: str, addon_name: str, access: str) -> str:
77
+ if addon_type == "postgres":
78
+ return f"/copilot/{self.application_name}/{self.env}/conduits/{self._normalise_secret_name(addon_name)}_{access.upper()}"
79
+ elif addon_type == "redis" or addon_type == "opensearch":
80
+ return f"/copilot/{self.application_name}/{self.env}/conduits/{self._normalise_secret_name(addon_name)}_ENDPOINT"
81
+ else:
82
+ return f"/copilot/{self.application_name}/{self.env}/conduits/{self._normalise_secret_name(addon_name)}"
83
+
84
+ def _normalise_secret_name(self, addon_name: str) -> str:
85
+ return addon_name.replace("-", "_").upper()
@@ -80,7 +80,7 @@ def load_application(app: str = None, default_session: Session = None) -> Applic
80
80
  WithDecryption=False,
81
81
  )
82
82
  except ssm_client.exceptions.ParameterNotFound:
83
- raise ApplicationNotFoundError
83
+ raise ApplicationNotFoundError(app)
84
84
 
85
85
  path = f"/copilot/applications/{application.name}/environments"
86
86
  secrets = get_ssm_secrets(app, None, current_session, path)
@@ -1,5 +1,6 @@
1
1
  import json
2
2
  import os
3
+ import time
3
4
  import urllib.parse
4
5
  from configparser import ConfigParser
5
6
  from pathlib import Path
@@ -15,6 +16,7 @@ from boto3 import Session
15
16
  from dbt_platform_helper.exceptions import AWSException
16
17
  from dbt_platform_helper.exceptions import CopilotCodebaseNotFoundError
17
18
  from dbt_platform_helper.exceptions import ImageNotFoundError
19
+ from dbt_platform_helper.exceptions import ResourceNotFoundException
18
20
  from dbt_platform_helper.exceptions import ValidationException
19
21
  from dbt_platform_helper.utils.files import cache_refresh_required
20
22
  from dbt_platform_helper.utils.files import read_supported_versions_from_cache
@@ -340,6 +342,7 @@ def get_load_balancer_configuration(
340
342
 
341
343
 
342
344
  def get_postgres_connection_data_updated_with_master_secret(session, parameter_name, secret_arn):
345
+ # Todo: This is pretty much the same as dbt_platform_helper.providers.secrets.Secrets.get_postgres_connection_data_updated_with_master_secret
343
346
  ssm_client = session.client("ssm")
344
347
  secrets_manager_client = session.client("secretsmanager")
345
348
  response = ssm_client.get_parameter(Name=parameter_name, WithDecryption=True)
@@ -414,7 +417,7 @@ def get_connection_string(
414
417
  app: str,
415
418
  env: str,
416
419
  db_identifier: str,
417
- connection_data_fn=get_postgres_connection_data_updated_with_master_secret,
420
+ connection_data=get_postgres_connection_data_updated_with_master_secret,
418
421
  ) -> str:
419
422
  addon_name = db_identifier.split(f"{app}-{env}-", 1)[1]
420
423
  normalised_addon_name = addon_name.replace("-", "_").upper()
@@ -426,7 +429,7 @@ def get_connection_string(
426
429
  Name=master_secret_name, WithDecryption=True
427
430
  )["Parameter"]["Value"]
428
431
 
429
- conn = connection_data_fn(session, connection_string_parameter, master_secret_arn)
432
+ conn = connection_data(session, connection_string_parameter, master_secret_arn)
430
433
 
431
434
  return f"postgres://{conn['username']}:{conn['password']}@{conn['host']}:{conn['port']}/{conn['dbname']}"
432
435
 
@@ -499,7 +502,7 @@ def check_codebase_exists(session: Session, application, codebase: str):
499
502
  ssm_client.exceptions.ParameterNotFound,
500
503
  json.JSONDecodeError,
501
504
  ):
502
- raise CopilotCodebaseNotFoundError
505
+ raise CopilotCodebaseNotFoundError(codebase)
503
506
 
504
507
 
505
508
  def check_image_exists(session, application, codebase, commit):
@@ -513,7 +516,7 @@ def check_image_exists(session, application, codebase, commit):
513
516
  ecr_client.exceptions.RepositoryNotFoundException,
514
517
  ecr_client.exceptions.ImageNotFoundException,
515
518
  ):
516
- raise ImageNotFoundError
519
+ raise ImageNotFoundError(commit)
517
520
 
518
521
 
519
522
  def get_build_url_from_arn(build_arn: str) -> str:
@@ -525,7 +528,7 @@ def get_build_url_from_arn(build_arn: str) -> str:
525
528
  )
526
529
 
527
530
 
528
- def list_latest_images(ecr_client, ecr_repository_name, codebase_repository, echo_fn):
531
+ def list_latest_images(ecr_client, ecr_repository_name, codebase_repository, echo):
529
532
  paginator = ecr_client.get_paginator("describe_images")
530
533
  describe_images_response_iterator = paginator.paginate(
531
534
  repositoryName=ecr_repository_name,
@@ -550,8 +553,28 @@ def list_latest_images(ecr_client, ecr_repository_name, codebase_repository, ech
550
553
  continue
551
554
 
552
555
  commit_hash = commit_tag.replace("commit-", "")
553
- echo_fn(
556
+ echo(
554
557
  f" - https://github.com/{codebase_repository}/commit/{commit_hash} - published: {image['imagePushedAt']}"
555
558
  )
556
559
  except StopIteration:
557
560
  continue
561
+
562
+
563
+ def wait_for_log_group_to_exist(log_client, log_group_name, attempts=30):
564
+ current_attempts = 0
565
+ log_group_exists = False
566
+
567
+ while not log_group_exists and current_attempts < attempts:
568
+ current_attempts += 1
569
+
570
+ log_group_response = log_client.describe_log_groups(logGroupNamePrefix=log_group_name)
571
+ log_groups = log_group_response.get("logGroups", [])
572
+
573
+ for group in log_groups:
574
+ if group["logGroupName"] == log_group_name:
575
+ log_group_exists = True
576
+
577
+ time.sleep(1)
578
+
579
+ if not log_group_exists:
580
+ raise ResourceNotFoundException
@@ -104,13 +104,13 @@ def validate_addons(addons: dict):
104
104
  config={"extensions": addons},
105
105
  extension_type="redis",
106
106
  version_key="engine",
107
- get_supported_versions_fn=get_supported_redis_versions,
107
+ get_supported_versions=get_supported_redis_versions,
108
108
  )
109
109
  _validate_extension_supported_versions(
110
110
  config={"extensions": addons},
111
111
  extension_type="opensearch",
112
112
  version_key="engine",
113
- get_supported_versions_fn=get_supported_opensearch_versions,
113
+ get_supported_versions=get_supported_opensearch_versions,
114
114
  )
115
115
 
116
116
  return errors
@@ -210,7 +210,12 @@ RETENTION_POLICY = Or(
210
210
  },
211
211
  )
212
212
 
213
- DATABASE_COPY = {"from": ENV_NAME, "to": ENV_NAME}
213
+ DATABASE_COPY = {
214
+ "from": ENV_NAME,
215
+ "to": ENV_NAME,
216
+ Optional("from_account"): str,
217
+ Optional("to_account"): str,
218
+ }
214
219
 
215
220
  POSTGRES_DEFINITION = {
216
221
  "type": "postgres",
@@ -558,18 +563,18 @@ def validate_platform_config(config):
558
563
  config=config,
559
564
  extension_type="redis",
560
565
  version_key="engine",
561
- get_supported_versions_fn=get_supported_redis_versions,
566
+ get_supported_versions=get_supported_redis_versions,
562
567
  )
563
568
  _validate_extension_supported_versions(
564
569
  config=config,
565
570
  extension_type="opensearch",
566
571
  version_key="engine",
567
- get_supported_versions_fn=get_supported_opensearch_versions,
572
+ get_supported_versions=get_supported_opensearch_versions,
568
573
  )
569
574
 
570
575
 
571
576
  def _validate_extension_supported_versions(
572
- config, extension_type, version_key, get_supported_versions_fn
577
+ config, extension_type, version_key, get_supported_versions
573
578
  ):
574
579
  extensions = config.get("extensions", {})
575
580
  if not extensions:
@@ -581,7 +586,7 @@ def _validate_extension_supported_versions(
581
586
  if extension.get("type") == extension_type
582
587
  ]
583
588
 
584
- supported_extension_versions = get_supported_versions_fn()
589
+ supported_extension_versions = get_supported_versions()
585
590
  extensions_with_invalid_version = []
586
591
 
587
592
  for extension in extensions_for_type:
@@ -590,13 +595,16 @@ def _validate_extension_supported_versions(
590
595
 
591
596
  if not isinstance(environments, dict):
592
597
  click.secho(
593
- "Error: Opensearch extension definition is invalid type, expected dictionary",
598
+ f"Error: {extension_type} extension definition is invalid type, expected dictionary",
594
599
  fg="red",
595
600
  )
596
601
  continue
597
602
  for environment, env_config in environments.items():
603
+
604
+ # An extension version doesn't need to be specified for all environments, provided one is specified under "*".
605
+ # So check if the version is set before checking if it's supported
598
606
  extension_version = env_config.get(version_key)
599
- if extension_version not in supported_extension_versions:
607
+ if extension_version and extension_version not in supported_extension_versions:
600
608
  extensions_with_invalid_version.append(
601
609
  {"environment": environment, "version": extension_version}
602
610
  )
@@ -635,6 +643,9 @@ def validate_database_copy_section(config):
635
643
  from_env = section["from"]
636
644
  to_env = section["to"]
637
645
 
646
+ from_account = _get_env_deploy_account_info(config, from_env, "id")
647
+ to_account = _get_env_deploy_account_info(config, to_env, "id")
648
+
638
649
  if from_env == to_env:
639
650
  errors.append(
640
651
  f"database_copy 'to' and 'from' cannot be the same environment in extension '{extension_name}'."
@@ -655,10 +666,33 @@ def validate_database_copy_section(config):
655
666
  f"database_copy 'to' parameter must be a valid environment ({all_envs_string}) but was '{to_env}' in extension '{extension_name}'."
656
667
  )
657
668
 
669
+ if from_account != to_account:
670
+ if "from_account" not in section:
671
+ errors.append(
672
+ f"Environments '{from_env}' and '{to_env}' are in different AWS accounts. The 'from_account' parameter must be present."
673
+ )
674
+ elif section["from_account"] != from_account:
675
+ errors.append(
676
+ f"Incorrect value for 'from_account' for environment '{from_env}'"
677
+ )
678
+
679
+ if "to_account" not in section:
680
+ errors.append(
681
+ f"Environments '{from_env}' and '{to_env}' are in different AWS accounts. The 'to_account' parameter must be present."
682
+ )
683
+ elif section["to_account"] != to_account:
684
+ errors.append(f"Incorrect value for 'to_account' for environment '{to_env}'")
685
+
658
686
  if errors:
659
687
  abort_with_error("\n".join(errors))
660
688
 
661
689
 
690
+ def _get_env_deploy_account_info(config, env, key):
691
+ return (
692
+ config.get("environments", {}).get(env, {}).get("accounts", {}).get("deploy", {}).get(key)
693
+ )
694
+
695
+
662
696
  def _validate_environment_pipelines(config):
663
697
  bad_pipelines = {}
664
698
  for pipeline_name, pipeline in config.get("environment_pipelines", {}).items():
@@ -666,13 +700,7 @@ def _validate_environment_pipelines(config):
666
700
  pipeline_account = pipeline.get("account", None)
667
701
  if pipeline_account:
668
702
  for env in pipeline.get("environments", {}).keys():
669
- env_account = (
670
- config.get("environments", {})
671
- .get(env, {})
672
- .get("accounts", {})
673
- .get("deploy", {})
674
- .get("name")
675
- )
703
+ env_account = _get_env_deploy_account_info(config, env, "name")
676
704
  if not env_account == pipeline_account:
677
705
  bad_envs.append(env)
678
706
  if bad_envs:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: dbt-platform-helper
3
- Version: 12.2.3
3
+ Version: 12.3.0
4
4
  Summary: Set of tools to help transfer applications/services from GOV.UK PaaS to DBT PaaS augmenting AWS Copilot.
5
5
  License: MIT
6
6
  Author: Department for Business and Trade Platform Team