data-factory-utils 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,54 @@
1
+ Metadata-Version: 2.3
2
+ Name: data-factory-utils
3
+ Version: 0.3.0
4
+ Summary: Utility functions for interacting with data factories.
5
+ Requires-Dist: boto3>=1.42.8
6
+ Requires-Dist: boto3-stubs>=1.42.89
7
+ Requires-Dist: botocore>=1.42.7
8
+ Requires-Dist: cloudpathlib>=0.23.0
9
+ Requires-Dist: mypy-boto3>=1.42.3
10
+ Requires-Dist: mypy-boto3-athena>=1.42.43
11
+ Requires-Dist: mypy-boto3-s3>=1.42.85
12
+ Requires-Dist: mypy-boto3-sts>=1.42.3
13
+ Requires-Dist: polars>=1.40.0
14
+ Requires-Python: >=3.12
15
+ Description-Content-Type: text/markdown
16
+
17
+ # data-factory-utils
18
+ A package of miscellaneous utilities for data factories.
19
+
20
+ ## Installation
21
+ This is a published package. Install using your favourite installation method.
22
+
23
+ ```bash
24
+ uv add data-factory-utils
25
+ pip install data-factory-utils
26
+ ```
27
+
28
+ ## Usage
29
+ ### Environment functions
30
+ These functions read from your data factory dynamically and infer the environment you are running in.
31
+
32
+ `Environment` is a singleton: no matter how many times you instantiate it, it reuses the state from its first initialisation. To create it:
33
+ ```python
34
+ from data_factory_utils.environment import Environment
35
+ env = Environment(use_web_identity=False)
36
+ ```
37
+
38
+ To return information about the environment (if we are in development with account number 0101010101):
39
+ ```python
40
+ env.account_no
41
+ # 0101010101
42
+ env.environment_name
43
+ # dev
44
+ env.is_prod
45
+ # False
46
+ ```
47
+
48
+ To get an S3 bucket (returned as a `cloudpathlib` `S3Path`), for example one named `emds-dev-random-name-202512161154001309058001`:
49
+
50
+ ```python
51
+ s3_random_name_bucket = env.get_full_bucket_url("random-name", full_prefix=True)
52
+ print(str(s3_random_name_bucket.bucket))
53
+ # emds-dev-random-name-202512161154001309058001
54
+ ```
@@ -0,0 +1,38 @@
1
+ # data-factory-utils
2
+ A package of miscellaneous utilities for data factories.
3
+
4
+ ## Installation
5
+ This is a published package. Install using your favourite installation method.
6
+
7
+ ```bash
8
+ uv add data-factory-utils
9
+ pip install data-factory-utils
10
+ ```
11
+
12
+ ## Usage
13
+ ### Environment functions
14
+ These functions read from your data factory dynamically and infer the environment you are running in.
15
+
16
+ `Environment` is a singleton: no matter how many times you instantiate it, it reuses the state from its first initialisation. To create it:
17
+ ```python
18
+ from data_factory_utils.environment import Environment
19
+ env = Environment(use_web_identity=False)
20
+ ```
21
+
22
+ To return information about the environment (if we are in development with account number 0101010101):
23
+ ```python
24
+ env.account_no
25
+ # 0101010101
26
+ env.environment_name
27
+ # dev
28
+ env.is_prod
29
+ # False
30
+ ```
31
+
32
+ To get an S3 bucket (returned as a `cloudpathlib` `S3Path`), for example one named `emds-dev-random-name-202512161154001309058001`:
33
+
34
+ ```python
35
+ s3_random_name_bucket = env.get_full_bucket_url("random-name", full_prefix=True)
36
+ print(str(s3_random_name_bucket.bucket))
37
+ # emds-dev-random-name-202512161154001309058001
38
+ ```
@@ -0,0 +1,87 @@
1
+ [project]
2
+ name = "data-factory-utils"
3
+ version = "0.3.0"
4
+ description = "Utility functions for interacting with data factories."
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "boto3>=1.42.8",
9
+ "boto3-stubs>=1.42.89",
10
+ "botocore>=1.42.7",
11
+ "cloudpathlib>=0.23.0",
12
+ "mypy-boto3>=1.42.3",
13
+ "mypy-boto3-athena>=1.42.43",
14
+ "mypy-boto3-s3>=1.42.85",
15
+ "mypy-boto3-sts>=1.42.3",
16
+ "polars>=1.40.0",
17
+ ]
18
+
19
+ [build-system]
20
+ requires = ["uv_build>=0.9.17,<0.10.0"]
21
+ build-backend = "uv_build"
22
+
23
+ [dependency-groups]
24
+ dev = [
25
+ "mypy>=1.20.1",
26
+ "prek>=0.2.21",
27
+ "ruff>=0.14.8",
28
+ "toml-cli>=0.8.2",
29
+ "ty>=0.0.1a33",
30
+ ]
31
+ test = [
32
+ "moto>=5.1.18",
33
+ "pytest>=9.0.2",
34
+ ]
35
+
36
+
37
+ [tool.ruff]
38
+ line-length = 120
39
+
40
+ [tool.bandit]
41
+ exclude_dirs = ["/tests", "/.venv"]
42
+
43
+ [tool.mypy]
44
+ strict = true
45
+ namespace_packages = false
46
+ disallow_untyped_defs = true
47
+ follow_untyped_imports = true
48
+ exclude = ["tests"]
49
+
50
+ [tool.ruff.lint]
51
+ select = ["ALL"]
52
+ # Remove warnings
53
+ ignore = ["D203", "D213", "COM812"]
54
+
55
+ [tool.ruff.lint.per-file-ignores]
56
+ "tests/**.py" = ["S101"]
57
+
58
+ [tool.semantic_release]
59
+ commit_message = "{version}\n\nAutomatically generated by python-semantic-release"
60
+ commit_parser = "conventional"
61
+ logging_use_named_masks = false
62
+ major_on_zero = false
63
+ allow_zero_version = true
64
+ no_git_verify = false
65
+ tag_format = "{version}"
66
+
67
+ [tool.semantic_release.branches.main]
68
+ match = "main"
69
+ prerelease_token = "rc"
70
+ prerelease = false
71
+
72
+ [tool.semantic_release.branches.other]
73
+ match = ".*"
74
+ prerelease_token = "rc"
75
+ prerelease = true
76
+
77
+ [tool.semantic_release.commit_parser_options]
78
+ minor_tags = ["feat"]
79
+ patch_tags = ["fix", "perf"]
80
+ other_allowed_tags = ["build", "chore", "ci", "docs", "style", "refactor", "test"]
81
+ allowed_tags = ["feat", "fix", "perf", "build", "chore", "ci", "docs", "style", "refactor", "test"]
82
+ default_bump_level = 0
83
+ parse_squash_commits = true
84
+ ignore_merge_commits = true
85
+
86
+ [tool.bandit.assert_used]
87
+ skips = ['*_test.py', '*/test_*.py']
@@ -0,0 +1 @@
1
+ """Init file."""
@@ -0,0 +1,208 @@
1
+ """Athena helper functions."""
2
+
3
+ import logging
4
+ import time
5
+ from collections.abc import Generator
6
+
7
+ import boto3
8
+ import polars as pl
9
+ from mypy_boto3_athena.client import AthenaClient
10
+ from mypy_boto3_athena.type_defs import GetQueryResultsOutputTypeDef, QueryExecutionContextTypeDef
11
+
12
+ from data_factory_utils.environment import Environment
13
+ from data_factory_utils.query import Query
14
+
15
# Logger for this module.
logger: logging.Logger = logging.getLogger(__name__)
# NOTE(review): basicConfig configures the *root* logger as an import-time side
# effect of this library module; kept as-is for backward compatibility.
logging.basicConfig(level=logging.INFO)
17
+
18
+
19
class NoQueryIdError(Exception):
    """Raised when Athena does not hand back a query execution ID after starting a query."""
21
+
22
+
23
class AthenaQueryError(Exception):
    """Raised when an Athena query cannot be stopped, fails, is cancelled, or times out.

    boto3 has no generic "query failed" exception, so this module defines its own.
    """
28
+
29
+
30
class AthenaConfig:
    """Container for Athena configuration.

    Building it (boto3 client, bucket lookup) is slow, so create it once and share it between queries.
    """

    athena_client: AthenaClient
    result_bucket: str
    database: str
    catalog: str | None
    workgroup: str

    def __init__(
        self,
        database: str,
        catalog: str | None = None,
        result_bucket_name: str = "athena-query-results",
        environment: Environment | None = None,
    ) -> None:
        """Initialise the configuration.

        Parameters
        ----------
        database : str
            Database used as the query execution context.
        catalog : str, optional
            Data catalog for the execution context; omitted from the request when None.
        result_bucket_name : str, optional
            Bucket-name fragment passed to ``Environment.get_full_bucket_url`` to find the results bucket.
        environment : Environment, optional
            Environment used to resolve the results bucket and account number. When omitted,
            ``Environment(use_web_identity=False)`` is used (it is a singleton, so an existing
            instance is reused).

        Raises
        ------
        ValueError
            If no bucket matching ``result_bucket_name`` exists in the environment.

        """
        if environment is None:
            environment = Environment(use_web_identity=False)
        result_bucket = environment.get_full_bucket_url(result_bucket_name, full_prefix=True)
        if result_bucket is None:
            # get_full_bucket_url returns None when nothing matches; fail fast here rather than
            # sending the string "None" to Athena as the OutputLocation.
            msg = f"No S3 bucket found matching {result_bucket_name!r} for Athena query results."
            raise ValueError(msg)
        self.result_bucket = str(result_bucket)
        self.athena_client = boto3.client("athena")
        self.database = database
        self.catalog = catalog
        self.workgroup = f"{environment.account_number}-ears-sars"
57
+
58
+
59
+ class AthenaQuery:
60
+ """Abstraction of an Athena query.
61
+
62
+ boto3's API triggers the query, gets the query status and retrieves the response as separate calls.
63
+ AthenaQuery implements methods to simplify this flow.
64
+ """
65
+
66
+ query: Query
67
+ execution_id: str | None
68
+ config: AthenaConfig
69
+ has_succeeded: bool
70
+
71
+ def __init__(self, query: Query, config: AthenaConfig) -> None:
72
+ """Initialise class."""
73
+ self.query = query
74
+ self.execution_id = None
75
+ self.config = config
76
+ self.has_succeeded = False
77
+
78
+ def run_and_await(self) -> None:
79
+ """Run query, wait for it to complete."""
80
+ self.execution_id = self.start_execution()
81
+ self.await_query_completion()
82
+
83
+ def start_execution(self) -> str:
84
+ """Begin an Athena query with standard settings. Return the query ID."""
85
+ query_context: QueryExecutionContextTypeDef = {"Database": self.config.database}
86
+ if self.config.catalog is not None:
87
+ query_context["Catalog"] = self.config.catalog
88
+
89
+ query_string = str(self.query).strip()
90
+ if not query_string:
91
+ msg = "Query string cannot be empty"
92
+ raise ValueError(msg)
93
+
94
+ logger.debug(query_string)
95
+ response = self.config.athena_client.start_query_execution(
96
+ QueryString=str(self.query),
97
+ ResultConfiguration={"OutputLocation": self.config.result_bucket},
98
+ QueryExecutionContext=query_context,
99
+ WorkGroup=self.config.workgroup,
100
+ )
101
+
102
+ self.execution_id = str(response["QueryExecutionId"])
103
+ if not self.execution_id:
104
+ msg = "No Query Execution Id. Response: %s"
105
+ raise NoQueryIdError(msg, response)
106
+
107
+ logger.info("Query Execution ID: %s", self.execution_id)
108
+ return self.execution_id
109
+
110
+ def stop_execution(self) -> None:
111
+ """Stop an Athena query's execution."""
112
+ if self.execution_id is None:
113
+ msg = "No query id available. Make sure that the query has been startedbefore attempting to stop it."
114
+ raise AthenaQueryError(msg)
115
+
116
+ self.config.athena_client.stop_query_execution(QueryExecutionId=self.execution_id)
117
+
118
+ def await_query_completion(
119
+ self,
120
+ max_iterations: int = 20,
121
+ max_interval_seconds: int = 120,
122
+ ) -> None:
123
+ """Wait for a managed query with the given ID to finish and return the result."""
124
+ if self.execution_id is None:
125
+ msg = "No query id available. Make sure that the query has been startedbefore querying the result."
126
+ raise ValueError(msg)
127
+
128
+ for i in range(max_iterations):
129
+ # Exponential backoff on checking query result
130
+ wait_time = min(max_interval_seconds, 2**i)
131
+ time.sleep(wait_time)
132
+
133
+ # Check whether the query has terminated
134
+ response = self.config.athena_client.get_query_execution(
135
+ QueryExecutionId=self.execution_id,
136
+ )
137
+ execution_status = response["QueryExecution"]["Status"]
138
+ state = execution_status["State"]
139
+ logger.debug("Query state: %s", state)
140
+
141
+ match state:
142
+ case "QUEUED" | "RUNNING":
143
+ continue
144
+ case "SUCCEEDED":
145
+ self.has_succeeded = True
146
+ logger.info("Query succeeded")
147
+ return
148
+ case "FAILED" | "CANCELLED":
149
+ msg = f"Query execution failed: {execution_status}"
150
+ raise AthenaQueryError(msg)
151
+ case _:
152
+ msg = f"Unknown state: {state}. Full status: {execution_status}"
153
+ raise AthenaQueryError(msg)
154
+
155
+ # If we've polled as many times as we allow and haven't had a result, cancel
156
+ # the query and throw an error.
157
+ self.stop_execution()
158
+ msg = f"Maximum configured query duration reached. Stopped query: {self.execution_id}"
159
+ raise AthenaQueryError(msg)
160
+
161
+ def _get_responses(self) -> Generator[GetQueryResultsOutputTypeDef]:
162
+ """Get the raw results of the query."""
163
+ # If we don't already know it to have finished, wait for the query to finish
164
+ if not self.has_succeeded:
165
+ self.await_query_completion()
166
+
167
+ if self.execution_id is None:
168
+ msg = "No query id available."
169
+ raise ValueError(msg)
170
+
171
+ response = self.config.athena_client.get_query_results(
172
+ QueryExecutionId=self.execution_id,
173
+ )
174
+ yield response
175
+
176
+ while "NextToken" in response:
177
+ response = self.config.athena_client.get_query_results(
178
+ QueryExecutionId=self.execution_id,
179
+ NextToken=response["NextToken"],
180
+ )
181
+ yield response
182
+
183
+ def parse_response_as_dataframe(self) -> pl.DataFrame:
184
+ """Return the table of results from the query as a dict of columns."""
185
+ # Get the raw query results from the API
186
+ responses = self._get_responses()
187
+
188
+ # Parse the raw row data from the results
189
+ rows = []
190
+ for response in responses:
191
+ for row in response["ResultSet"]["Rows"]:
192
+ row_data = [column.get("VarCharValue") for column in row["Data"]]
193
+ rows.append(row_data)
194
+
195
+ if len(rows) == 0:
196
+ return pl.DataFrame({})
197
+
198
+ # The first row returned by the response is the names of the columns.
199
+ column_names, values_in_rows = rows[0], rows[1:]
200
+
201
+ if len(values_in_rows) == 0:
202
+ return pl.DataFrame(data={col: [] for col in column_names})
203
+
204
+ # Convert a list of rows into a list of columns
205
+ values_in_columns = list(zip(*values_in_rows, strict=True))
206
+
207
+ # Package it up into a dictionary of columns
208
+ return pl.DataFrame(dict(zip(column_names, values_in_columns, strict=True)))
@@ -0,0 +1,334 @@
1
+ """Class to help with AWS environment inference and configuration."""
2
+
3
+ import logging
4
+ import os
5
+ import secrets
6
+ import string
7
+ from pathlib import Path
8
+ from typing import Self
9
+
10
+ import boto3
11
+ import botocore.exceptions
12
+ from cloudpathlib import S3Path
13
+ from mypy_boto3_s3.type_defs import BucketTypeDef
14
+ from mypy_boto3_sts.type_defs import CredentialsTypeDef
15
+
16
# Logger for this module.
logger: logging.Logger = logging.getLogger(__name__)
# NOTE(review): basicConfig configures the *root* logger as an import-time side
# effect of this library module; kept as-is for backward compatibility.
logging.basicConfig(level=logging.INFO)

# Character pool for the random suffix appended to STS role session names.
alphabet: str = string.ascii_lowercase
20
+
21
+
22
def _raise_missing_token() -> None:
    """Log and raise the error for an unset ``AWS_WEB_IDENTITY_TOKEN_FILE``.

    Raises
    ------
    RuntimeError
        Always.

    """
    msg = "AWS_WEB_IDENTITY_TOKEN_FILE environment variable not set."
    # logger.error rather than logger.exception: no exception is being handled here, so
    # exception() would only append a meaningless "NoneType: None" traceback.
    logger.error(msg)
    raise RuntimeError(msg)
26
+
27
+
28
def _raise_missing_role_arn() -> None:
    """Unconditionally raise ``MissingEnvVarRoleArnError``."""
    error = MissingEnvVarRoleArnError()
    raise error
30
+
31
+
32
class MissingEnvVarRoleArnError(Exception):
    """Raised when the ``AWS_ROLE_ARN`` environment variable is not set."""

    def __init__(self) -> None:
        """Build the error with a fixed, actionable message."""
        self.message = "Add AWS_ROLE_ARN to environment variables."
        # Hand the message to Exception so str(err), logs and tracebacks show it.
        super().__init__(self.message)
39
+
40
+
41
class Environment:
    """AWS environment inference and configuration.

    This class helps determine the environment (prod, preprod, test, dev),
    manage AWS credentials (via web identity or the default chain),
    and construct environment-specific S3 bucket URLs.

    It is a singleton: ``Environment(...)`` always returns the same instance and the
    set-up in ``__init__`` runs only once, so arguments passed after the first
    construction are ignored. Call ``Environment.clear()`` to start over.
    """

    _instance: Self | None = None

    def __new__(cls, *_args: object, **_kwargs: object) -> Self:
        """Return the singleton instance, creating it on the first call."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self,
        job_name: str | None = "",
        bucket_prefix: str | None = "emds",
        *,
        use_web_identity: bool,
    ) -> None:
        """Initialize the environment context (first call only; later calls return immediately).

        Parameters
        ----------
        job_name : str, optional
            Name used in the STS role session name, by default empty string.
        bucket_prefix : str, optional
            Base prefix used to identify datahub buckets, by default "emds".
        use_web_identity : bool
            Whether to build the session by assuming ``AWS_ROLE_ARN`` with the web identity
            token (raising if that fails) instead of using the default credential chain.

        """
        if getattr(self, "_initialized", False):
            return

        self.job_name = job_name
        self.bucket_prefix = bucket_prefix
        self.use_web_identity = use_web_identity

        # NOTE(review): the session is exposed for callers, but the lookups below (and the
        # other methods of this class) create clients via boto3.client, not this session.
        self.session = self._init_session()
        self.alias = self._fetch_account_alias()
        self.account_no = self._get_account_number()
        self.bucket_list = self._list_buckets()

        self._initialized: bool = True

    def _assume_role_with_web_identity(self, session_name: str) -> CredentialsTypeDef:
        """Exchange the web identity token for temporary credentials of ``AWS_ROLE_ARN``.

        Parameters
        ----------
        session_name : str
            RoleSessionName passed to STS AssumeRoleWithWebIdentity.

        Returns
        -------
        CredentialsTypeDef
            The ``Credentials`` block of the STS response (requested for 900 seconds).

        Raises
        ------
        RuntimeError
            If ``AWS_WEB_IDENTITY_TOKEN_FILE`` is not set.
        MissingEnvVarRoleArnError
            If ``AWS_ROLE_ARN`` is not set.

        """
        token_path = os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
        if token_path is None:
            _raise_missing_token()
        else:
            with Path(token_path).open() as f:
                web_identity_token = f.read()

        role_arn = os.environ.get("AWS_ROLE_ARN")
        if role_arn is None:
            _raise_missing_role_arn()
        else:
            role_arn_arg = role_arn

        sts_client = boto3.client("sts")
        response = sts_client.assume_role_with_web_identity(
            RoleArn=role_arn_arg,
            RoleSessionName=session_name,
            WebIdentityToken=web_identity_token,
            DurationSeconds=900,
        )
        return response["Credentials"]

    def _init_session(self) -> boto3.Session:
        """Initialise a boto3 session, optionally using web identity credentials.

        Returns
        -------
        boto3.Session
            Session holding the assumed-role credentials when ``use_web_identity`` is set,
            otherwise a default session.

        Raises
        ------
        Exception
            Any error from the web identity flow is logged and re-raised; there is no fallback.

        """
        if not self.use_web_identity:
            return boto3.session.Session()

        try:
            credentials = self._assume_role_with_web_identity(f"session-{self.job_name}")
        except Exception as e:
            logger.warning("Web identity failed: %s.", e)
            raise

        return boto3.session.Session(
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
        )

    def _fetch_account_alias(self) -> str:
        """Fetch the AWS account alias.

        Returns
        -------
        str
            Account alias or default.

        Notes
        -----
        Falls back to 'preproduction' alias if none found or the IAM call fails.

        """
        try:
            aliases = boto3.client("iam").list_account_aliases().get("AccountAliases", [])
            return aliases[0] if aliases else "electronic-monitoring-data-preproduction"
        except botocore.exceptions.ClientError:
            logger.warning("Failed to fetch account alias, assuming preproduction.")
            return "electronic-monitoring-data-preproduction"

    def _get_account_number(self) -> str:
        """Return the AWS account number from STS GetCallerIdentity.

        Raises
        ------
        RuntimeError
            If no AWS credentials are available.

        """
        try:
            return boto3.client("sts").get_caller_identity()["Account"]
        except botocore.exceptions.NoCredentialsError:
            msg = "AWS credentials not found."
            logger.exception(msg)
            raise RuntimeError(msg) from None

    def _list_buckets(self) -> list[BucketTypeDef]:
        """List all available S3 buckets (errors are logged and re-raised)."""
        try:
            return boto3.client("s3").list_buckets()["Buckets"]
        except Exception as e:
            logger.warning("Could not list buckets: %s", e)
            raise

    @property
    def account_number(self) -> str:
        """Return the AWS account number."""
        return self.account_no

    @property
    def environment_name(self) -> str:
        """Infer environment name from the last dash-separated part of the account alias.

        Returns
        -------
        str
            One of: prod, dev, preprod, test, or fallback to raw alias suffix.

        """
        full_env_name = self.alias.split("-")[-1]
        return {
            "production": "prod",
            "development": "dev",
            "preproduction": "preprod",
            "test": "test",
        }.get(full_env_name, full_env_name)

    @property
    def is_prod(self) -> bool:
        """Check if the environment is production."""
        return self.environment_name == "prod"

    def get_full_bucket_url(
        self,
        bucket_prefix: str | None = None,
        *,
        full_prefix: bool,
    ) -> S3Path | None:
        """Get S3Path to the first bucket matching environment and prefix.

        The expected name is ``{self.bucket_prefix}-{environment_name}-{bucket_prefix}``.

        Parameters
        ----------
        bucket_prefix : str, optional
            Prefix to search for (overrides default prefix).
        full_prefix : bool
            If True, the bucket name must equal the expected name, either exactly or once its
            last dash-separated segment is removed. If False, the expected name only has to
            appear somewhere in the bucket name.

        Returns
        -------
        Optional[S3Path]
            S3Path to the matched bucket or None if not found.

        """
        search_prefix = bucket_prefix or self.bucket_prefix
        expected_name = f"{self.bucket_prefix}-{self.environment_name}-{search_prefix}"

        for bucket in self.bucket_list:
            bucket_name = bucket["Name"]
            if full_prefix:
                if expected_name == "-".join(bucket_name.split("-")[:-1]) or expected_name == bucket_name:
                    return S3Path(f"s3://{bucket_name}")
            elif expected_name in bucket_name:
                return S3Path(f"s3://{bucket_name}")
        return None

    def get_api_invoke_url(self, api_name: str, region: str) -> str:
        """Return the invoke URL of the REST API named ``api_name`` in ``region``.

        Raises
        ------
        ValueError
            If no REST API, or more than one, has that name.

        """
        client = boto3.client("apigateway", region)
        # GetRestApis returns a single page per call; paginate so APIs beyond the first page are found.
        paginator = client.get_paginator("get_rest_apis")
        matches = [
            api for page in paginator.paginate() for api in page.get("items", []) if api.get("name") == api_name
        ]
        if len(matches) != 1:
            msg = f"Expected exactly one REST API named {api_name!r} in {region}, found {len(matches)}."
            raise ValueError(msg)
        return f"https://{matches[0]['id']}.execute-api.{region}.amazonaws.com/"

    def refresh_credentials(self) -> CredentialsTypeDef:
        """Refresh credentials via STS.

        Tries AssumeRoleWithWebIdentity (session name gets a random suffix); on any failure,
        logs it and falls back to GetSessionToken with the default credentials.

        Returns
        -------
        dict
            New credentials (requested for 900 seconds).

        """
        rand_suffix = "".join(secrets.choice(alphabet) for _ in range(10))
        try:
            return self._assume_role_with_web_identity(f"session-{self.job_name}-{rand_suffix}")
        except Exception:
            logger.exception("Web identity failed. Falling back to get_session_token.")

        sts_client = boto3.client("sts")
        response_session_token = sts_client.get_session_token(DurationSeconds=900)
        return response_session_token["Credentials"]

    def export_dbt_variables(self, *, actions: bool = False, airflow: bool = False) -> None:
        """Write dbt environment variables for this environment to ``set_env.sh``.

        Parameters
        ----------
        actions : bool, optional
            Emit GitHub Actions lines (``echo "NAME=value" >> "$GITHUB_ENV"``) instead of
            shell ``export NAME='value'`` lines.
        airflow : bool, optional
            Omit ``DBT_PROFILES_DIR``. It is only written in shell (non-actions) mode.

        """
        s3_data_bucket_name = self.get_full_bucket_url("cadt", full_prefix=True)
        variables = {
            "DBT_SUFFIX": "" if self.is_prod else f"_{self.environment_name}_dbt",
            # NOTE(review): this is the full S3Path string ("s3://..."), or "None" when no bucket matched.
            "S3_DATA_BUCKET_NAME": str(s3_data_bucket_name),
            "DBT_TEST_PROFILE_WORKGROUP": f"{self.account_number}-default",
        }
        if not actions and not airflow:
            variables["DBT_PROFILES_DIR"] = "../.dbt/"
        variables["H3_LAMBDA_ARN"] = f"arn:aws:lambda:eu-west-2:{self.account_no}:function:h3-udf"

        # One newline-terminated line per variable, so no two entries run together.
        if actions:
            lines = [f'echo "{name}={value}" >> "$GITHUB_ENV"\n' for name, value in variables.items()]
        else:
            lines = [f"export {name}='{value}'\n" for name, value in variables.items()]

        with Path("set_env.sh").open("w", encoding="utf-8") as f:
            f.writelines(lines)

    @classmethod
    def clear(cls) -> None:
        """Reset the singleton instance.

        Use this to force restart of the class. Mainly for testing.
        """
        cls._instance = None

    @classmethod
    def instance(cls) -> "Environment | None":
        """Return the current singleton instance, if any. Mainly for testing."""
        return cls._instance
@@ -0,0 +1,232 @@
1
+ """Build safe trino query."""
2
+
3
+ import re
4
+ import textwrap
5
+ from datetime import UTC, date, datetime
6
+ from typing import ClassVar, Literal, Self
7
+
8
+ IDENTIFIER = re.compile(r"^[a-z_][a-z0-9_\$]*$")
9
+ DISALLOWED_CHARACTERS = re.compile(r"[;\n\r]")
10
+ MAX_LIST_LENGTH = 1000
11
+
12
+ Column = str
13
+ OrderByClause = tuple[Column, Literal["ASC", "DESC"]]
14
+ DatesConditions = tuple[Literal["<", ">", "=", "<=", ">="], datetime | str]
15
+
16
+
17
def _ensure_safe_fragment(fragment: str) -> None:
    """Reject text containing any character matched by ``DISALLOWED_CHARACTERS``.

    Raises
    ------
    ValueError
        If ``fragment`` contains a disallowed character.

    """
    offending = DISALLOWED_CHARACTERS.search(fragment)
    if offending is None:
        return
    msg = "Disallowed characters in fragment"
    raise ValueError(msg)
22
+
23
+
24
def validate_identifier(name: str) -> str:
    """Ensure ``name`` is a valid SQL/Athena identifier and return it unchanged.

    Raises
    ------
    ValueError
        If ``name`` does not match ``IDENTIFIER`` over its whole length.

    """
    # fullmatch, not match: with a "$"-anchored pattern, match() accepts a trailing newline.
    if IDENTIFIER.fullmatch(name) is None:
        msg = f"Invalid identifier: {name!r}"
        raise ValueError(msg)
    return name
30
+
31
+
32
def validate_order_by_direction(direction: str) -> None:
    """Raise ValueError unless ``direction`` is ASC or DESC (compared case-insensitively)."""
    allowed_directions = {"ASC", "DESC"}
    if direction.upper() in allowed_directions:
        return
    msg = f"Invalid direction in order by clause: {direction}"
    raise ValueError(msg)
37
+
38
+
39
+ def quote_literal(value: None | str | float | datetime | date) -> str:
40
+ """Safely quote literal values for SQL."""
41
+ if isinstance(value, bool):
42
+ return "TRUE" if value else "FALSE"
43
+ if value is None:
44
+ return "NULL"
45
+ if isinstance(value, (int, float)):
46
+ return str(value)
47
+ if isinstance(value, str):
48
+ _ensure_safe_fragment(value)
49
+ escaped = value.replace("'", "''")
50
+ return f"'{escaped}'"
51
+ if isinstance(value, (date, datetime)):
52
+ return f"cast('{value.strftime('%Y-%m-%d %H:%M:%S')}' as timestamp(6))"
53
+ return None
54
+
55
+
56
+ def quote_list_literal(values: list[str | int | float] | tuple[str | int | float]) -> str:
57
+ """Quote a list into trino."""
58
+ vals = list(values)
59
+ if len(vals) > MAX_LIST_LENGTH:
60
+ msg = "Too many items in list"
61
+ raise ValueError(msg)
62
+ return "(" + ", ".join(quote_literal(v) for v in vals) + ")"
63
+
64
+
65
+ class Query:
66
+ """A safe Athena SQL query builder."""
67
+
68
+ keywords: ClassVar[list[str]] = [
69
+ "WITH",
70
+ "SELECT",
71
+ "FROM",
72
+ "WHERE",
73
+ "GROUP BY",
74
+ "HAVING",
75
+ "ORDER BY",
76
+ "LIMIT",
77
+ ]
78
+
79
+ def __init__(self) -> None:
80
+ """Init function."""
81
+ self.parts: dict[str, list[str]] = {kw: [] for kw in self.keywords}
82
+
83
+ def SELECT(self, *columns: str, distinct: bool = False) -> Self: # noqa: N802
84
+ """Wrap selected columns."""
85
+ for c in columns:
86
+ if c != "*":
87
+ for part in c.split(","):
88
+ validate_identifier(part)
89
+ col_list = [f'"{c}"' if c != "*" else "*" for c in columns]
90
+ if distinct:
91
+ col_list[0] = "DISTINCT " + col_list[0]
92
+ self.parts["SELECT"].extend(col_list)
93
+ return self
94
+
95
+ def FROM(self, database: str, table: str) -> Self: # noqa: N802
96
+ """From database name table name."""
97
+ validate_identifier(database)
98
+ validate_identifier(table)
99
+ self.parts["FROM"].append(f'"{database}"."{table}"')
100
+ return self
101
+
102
+ def WHERE( # noqa: N802
103
+ self, *, unquote: bool = False, **conditions: str | list[str | int | float] | tuple[str | int | float]
104
+ ) -> Self:
105
+ """Where col=value style."""
106
+ for col, val in conditions.items():
107
+ validate_identifier(col)
108
+ if isinstance(val, (list, tuple)):
109
+ expr = f'"{col}" IN {val}' if unquote else f'"{col}" IN {quote_list_literal(val)}'
110
+ else:
111
+ expr = f'"{col}" = {val}' if unquote else f'"{col}" = {quote_literal(val)}'
112
+ self.parts["WHERE"].append(expr)
113
+ return self
114
+
115
+ def WHERE_LIKE( # noqa: N802
116
+ self,
117
+ field_wrapper: Literal["", "upper", "lower"] = "",
118
+ connector: Literal["", "OR", "AND"] = "",
119
+ **conditions: str,
120
+ ) -> Self:
121
+ """Where Like filter."""
122
+ for col, val in conditions.items():
123
+ validate_identifier(col)
124
+ if isinstance(val, list):
125
+ range_exprs = [
126
+ f'{field_wrapper}("{col}") LIKE {field_wrapper}({quote_literal("%" + str(option) + "%")})'
127
+ for option in val
128
+ ]
129
+ expr = "(" + f" {connector} ".join(range_exprs) + ")"
130
+ else:
131
+ expr = f'{field_wrapper}("{col}") LIKE {field_wrapper}({quote_literal("%" + str(val) + "%")})'
132
+ self.parts["WHERE"].append(expr)
133
+ return self
134
+
135
+ def DATES(self, **conditions: DatesConditions) -> Self: # noqa: N802
136
+ """Add specific date filtering
137
+ Accepts conditions where the key is the column name and value is a tuple: (operator, date_string)
138
+ Intended to be used to constructing a bounding range of dates rather than dates between two.
139
+ For dates between two values, see DATE_RANGES.
140
+
141
+ Example:
142
+ -------
143
+ - QUERY().DATES(field_1=('<', 2023-01-01))
144
+
145
+ """ # noqa: D205
146
+ valid_operators = {"<", ">", "=", "<=", ">="}
147
+ expressions = []
148
+ for col, (operator, date_lit) in conditions.items():
149
+ # Validation
150
+ if operator not in valid_operators:
151
+ err_txt = f"Invalid operator: {operator}"
152
+ raise ValueError(err_txt)
153
+ validate_identifier(col)
154
+
155
+ if isinstance(date_lit, str):
156
+ dt = datetime.strptime(date_lit, "%Y-%m-%d %H:%M:%S").astimezone(UTC)
157
+ dt.replace(tzinfo=UTC)
158
+ expr = f'"{col}" {operator} {quote_literal(dt)}'
159
+ elif isinstance(date_lit, datetime):
160
+ date_lit.replace(tzinfo=UTC)
161
+ expr = f'"{col}" {operator} {quote_literal(date_lit)}'
162
+ expressions.append(expr)
163
+ group_expr = "(" + " AND ".join(expressions) + ")"
164
+
165
+ if not expressions:
166
+ self.parts["WHERE"].append("")
167
+ else:
168
+ self.parts["WHERE"].append(group_expr)
169
+ return self
170
+
171
+ def DATE_RANGES(self, cols: list[str], start: datetime, end: datetime) -> Self: # noqa: N802
172
+ """Add a grouped OR condition for date ranges: (col1 BETWEEN ... OR col2 BETWEEN ...)."""
173
+ for c in cols:
174
+ validate_identifier(c)
175
+
176
+ start_lit = quote_literal(start)
177
+ end_lit = quote_literal(end)
178
+
179
+ range_exprs = [f'"{c}" BETWEEN {start_lit} AND {end_lit}' for c in cols]
180
+ group_expr = "(" + " OR ".join(range_exprs) + ")"
181
+ self.parts["WHERE"].append(group_expr)
182
+ return self
183
+
184
+ def ORDER_BY(self, *order_by_clauses: str | OrderByClause) -> Self: # noqa: N802
185
+ """Order by action."""
186
+ for clause in order_by_clauses:
187
+ if isinstance(clause, tuple):
188
+ col, direction = clause
189
+ for part in col.split("."):
190
+ validate_identifier(part)
191
+ validate_order_by_direction(direction)
192
+ self.parts["ORDER BY"].append(f'"{col}" {direction}')
193
+ else:
194
+ # Allow ordering by columns without direction
195
+ col = clause
196
+ for part in col.split("."):
197
+ validate_identifier(part)
198
+ self.parts["ORDER BY"].append(f'"{col}"')
199
+
200
+ return self
201
+
202
+ def LIMIT(self, n: int) -> Self: # noqa: N802
203
+ """Limit query."""
204
+ if not isinstance(n, int) or n < 0:
205
+ msg = "LIMIT must be a non-negative integer"
206
+ raise ValueError(msg)
207
+ self.parts["LIMIT"].append(str(n))
208
+ return self
209
+
210
+ def __str__(self) -> str:
211
+ """Create query."""
212
+ sql_lines = []
213
+ for kw in self.keywords:
214
+ vals = self.parts[kw]
215
+ if not vals:
216
+ continue
217
+ match kw:
218
+ case "SELECT" | "ORDER BY" | "GROUP BY":
219
+ sql_lines.append(f"{kw} " + ", ".join(vals))
220
+ case "FROM":
221
+ sql_lines.append(f"{kw} " + ", ".join(vals))
222
+ case "WHERE" | "HAVING":
223
+ sql_lines.append(f"{kw} " + " AND ".join(vals))
224
+ case "LIMIT":
225
+ sql_lines.append(f"{kw} {vals[0]}")
226
+ case _:
227
+ sql_lines.append(f"{kw} {' '.join(vals)}")
228
+ return "\n".join(sql_lines)
229
+
230
+ def pretty(self) -> str:
231
+ """Make SQL look nice."""
232
+ return textwrap.indent(str(self), " ")