data-factory-utils 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_factory_utils-0.3.0/PKG-INFO +54 -0
- data_factory_utils-0.3.0/README.md +38 -0
- data_factory_utils-0.3.0/pyproject.toml +87 -0
- data_factory_utils-0.3.0/src/data_factory_utils/__init__.py +1 -0
- data_factory_utils-0.3.0/src/data_factory_utils/athena.py +208 -0
- data_factory_utils-0.3.0/src/data_factory_utils/environment.py +334 -0
- data_factory_utils-0.3.0/src/data_factory_utils/query.py +232 -0
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: data-factory-utils
|
|
3
|
+
Version: 0.3.0
|
|
4
|
+
Summary: Utility functions for interacting with data factories.
|
|
5
|
+
Requires-Dist: boto3>=1.42.8
|
|
6
|
+
Requires-Dist: boto3-stubs>=1.42.89
|
|
7
|
+
Requires-Dist: botocore>=1.42.7
|
|
8
|
+
Requires-Dist: cloudpathlib>=0.23.0
|
|
9
|
+
Requires-Dist: mypy-boto3>=1.42.3
|
|
10
|
+
Requires-Dist: mypy-boto3-athena>=1.42.43
|
|
11
|
+
Requires-Dist: mypy-boto3-s3>=1.42.85
|
|
12
|
+
Requires-Dist: mypy-boto3-sts>=1.42.3
|
|
13
|
+
Requires-Dist: polars>=1.40.0
|
|
14
|
+
Requires-Python: >=3.12
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
|
|
17
|
+
# data-factory-utils
|
|
18
|
+
A collection of utility functions for working with data factories.
|
|
19
|
+
|
|
20
|
+
## Installation
|
|
21
|
+
This is a published package. Install using your favourite installation method.
|
|
22
|
+
|
|
23
|
+
```bash
|
|
24
|
+
uv add data-factory-utils
|
|
25
|
+
pip install data-factory-utils
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
## Usage
|
|
29
|
+
### Environment functions
|
|
30
|
+
These functions read from your data factory dynamically and infer which environment you are in.
|
|
31
|
+
|
|
32
|
+
`Environment` is a singleton: however many times you instantiate it, you get the same instance, which re-uses the variables loaded the first time. For example:
|
|
33
|
+
```python
|
|
34
|
+
from data_factory_utils.environment import Environment
|
|
35
|
+
env = Environment(use_web_identity=False)
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
To get information about the environment (for example, in development with account number 0101010101):
|
|
39
|
+
```python
|
|
40
|
+
env.account_no
|
|
41
|
+
# 0101010101
|
|
42
|
+
env.environment_name
|
|
43
|
+
# dev
|
|
44
|
+
env.is_prod
|
|
45
|
+
# False
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
To get the path of an S3 bucket (returned as a `cloudpathlib` `S3Path`), for example for a bucket named `emds-dev-random-name-202512161154001309058001`:
|
|
49
|
+
|
|
50
|
+
```python
|
|
51
|
+
s3_random_name_bucket = env.get_full_bucket_url("random-name", full_prefix=True)
|
|
52
|
+
print(str(s3_random_name_bucket.bucket))
|
|
53
|
+
# emds-dev-random-name-202512161154001309058001
|
|
54
|
+
```
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# data-factory-utils
|
|
2
|
+
A collection of utility functions for working with data factories.
|
|
3
|
+
|
|
4
|
+
## Installation
|
|
5
|
+
This is a published package. Install using your favourite installation method.
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
uv add data-factory-utils
|
|
9
|
+
pip install data-factory-utils
|
|
10
|
+
```
|
|
11
|
+
|
|
12
|
+
## Usage
|
|
13
|
+
### Environment functions
|
|
14
|
+
These functions read from your data factory dynamically and infer which environment you are in.
|
|
15
|
+
|
|
16
|
+
`Environment` is a singleton: however many times you instantiate it, you get the same instance, which re-uses the variables loaded the first time. For example:
|
|
17
|
+
```python
|
|
18
|
+
from data_factory_utils.environment import Environment
|
|
19
|
+
env = Environment(use_web_identity=False)
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
To get information about the environment (for example, in development with account number 0101010101):
|
|
23
|
+
```python
|
|
24
|
+
env.account_no
|
|
25
|
+
# 0101010101
|
|
26
|
+
env.environment_name
|
|
27
|
+
# dev
|
|
28
|
+
env.is_prod
|
|
29
|
+
# False
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
To get the path of an S3 bucket (returned as a `cloudpathlib` `S3Path`), for example for a bucket named `emds-dev-random-name-202512161154001309058001`:
|
|
33
|
+
|
|
34
|
+
```python
|
|
35
|
+
s3_random_name_bucket = env.get_full_bucket_url("random-name", full_prefix=True)
|
|
36
|
+
print(str(s3_random_name_bucket.bucket))
|
|
37
|
+
# emds-dev-random-name-202512161154001309058001
|
|
38
|
+
```
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "data-factory-utils"
|
|
3
|
+
version = "0.3.0"
|
|
4
|
+
description = "Utility functions for interacting with data factories."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.12"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"boto3>=1.42.8",
|
|
9
|
+
"boto3-stubs>=1.42.89",
|
|
10
|
+
"botocore>=1.42.7",
|
|
11
|
+
"cloudpathlib>=0.23.0",
|
|
12
|
+
"mypy-boto3>=1.42.3",
|
|
13
|
+
"mypy-boto3-athena>=1.42.43",
|
|
14
|
+
"mypy-boto3-s3>=1.42.85",
|
|
15
|
+
"mypy-boto3-sts>=1.42.3",
|
|
16
|
+
"polars>=1.40.0",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
[build-system]
|
|
20
|
+
requires = ["uv_build>=0.9.17,<0.10.0"]
|
|
21
|
+
build-backend = "uv_build"
|
|
22
|
+
|
|
23
|
+
[dependency-groups]
|
|
24
|
+
dev = [
|
|
25
|
+
"mypy>=1.20.1",
|
|
26
|
+
"prek>=0.2.21",
|
|
27
|
+
"ruff>=0.14.8",
|
|
28
|
+
"toml-cli>=0.8.2",
|
|
29
|
+
"ty>=0.0.1a33",
|
|
30
|
+
]
|
|
31
|
+
test = [
|
|
32
|
+
"moto>=5.1.18",
|
|
33
|
+
"pytest>=9.0.2",
|
|
34
|
+
]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
[tool.ruff]
|
|
38
|
+
line-length = 120
|
|
39
|
+
|
|
40
|
+
[tool.bandit]
|
|
41
|
+
exclude_dirs = ["/tests", "/.venv"]
|
|
42
|
+
|
|
43
|
+
[tool.mypy]
|
|
44
|
+
strict = true
|
|
45
|
+
namespace_packages = false
|
|
46
|
+
disallow_untyped_defs = true
|
|
47
|
+
follow_untyped_imports = true
|
|
48
|
+
exclude = ["tests"]
|
|
49
|
+
|
|
50
|
+
[tool.ruff.lint]
|
|
51
|
+
select = ["ALL"]
|
|
52
|
+
# Remove warnings
|
|
53
|
+
ignore = ["D203", "D213", "COM812"]
|
|
54
|
+
|
|
55
|
+
[tool.ruff.lint.per-file-ignores]
|
|
56
|
+
"tests/**.py" = ["S101"]
|
|
57
|
+
|
|
58
|
+
[tool.semantic_release]
|
|
59
|
+
commit_message = "{version}\n\nAutomatically generated by python-semantic-release"
|
|
60
|
+
commit_parser = "conventional"
|
|
61
|
+
logging_use_named_masks = false
|
|
62
|
+
major_on_zero = false
|
|
63
|
+
allow_zero_version = true
|
|
64
|
+
no_git_verify = false
|
|
65
|
+
tag_format = "{version}"
|
|
66
|
+
|
|
67
|
+
[tool.semantic_release.branches.main]
|
|
68
|
+
match = "main"
|
|
69
|
+
prerelease_token = "rc"
|
|
70
|
+
prerelease = false
|
|
71
|
+
|
|
72
|
+
[tool.semantic_release.branches.other]
|
|
73
|
+
match = ".*"
|
|
74
|
+
prerelease_token = "rc"
|
|
75
|
+
prerelease = true
|
|
76
|
+
|
|
77
|
+
[tool.semantic_release.commit_parser_options]
|
|
78
|
+
minor_tags = ["feat"]
|
|
79
|
+
patch_tags = ["fix", "perf"]
|
|
80
|
+
other_allowed_tags = ["build", "chore", "ci", "docs", "style", "refactor", "test"]
|
|
81
|
+
allowed_tags = ["feat", "fix", "perf", "build", "chore", "ci", "docs", "style", "refactor", "test"]
|
|
82
|
+
default_bump_level = 0
|
|
83
|
+
parse_squash_commits = true
|
|
84
|
+
ignore_merge_commits = true
|
|
85
|
+
|
|
86
|
+
[tool.bandit.assert_used]
|
|
87
|
+
skips = ['*_test.py', '*/test_*.py']
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Init file."""
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
"""Athena helper functions."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import time
|
|
5
|
+
from collections.abc import Generator
|
|
6
|
+
|
|
7
|
+
import boto3
|
|
8
|
+
import polars as pl
|
|
9
|
+
from mypy_boto3_athena.client import AthenaClient
|
|
10
|
+
from mypy_boto3_athena.type_defs import GetQueryResultsOutputTypeDef, QueryExecutionContextTypeDef
|
|
11
|
+
|
|
12
|
+
from data_factory_utils.environment import Environment
|
|
13
|
+
from data_factory_utils.query import Query
|
|
14
|
+
|
|
15
|
+
# NOTE(review): basicConfig at import time configures the *root* logger for any
# application that imports this module (it is a no-op only if the root logger
# already has handlers). Consider leaving logging configuration to the application.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the Athena helpers below.
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class NoQueryIdError(Exception):
    """Raised when Athena's response to starting a query carries no execution ID."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AthenaQueryError(Exception):
    """Raised when an Athena query cannot be run to a successful result.

    boto3 offers no generic "query failed" exception, so this module defines its own
    to signal that a query failed.
    """
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class AthenaConfig:
    """Container for Athena configuration.

    It's slow to create so we do it once and share it between queries.
    """

    athena_client: AthenaClient
    result_bucket: str
    database: str
    catalog: str | None
    workgroup: str

    def __init__(
        self,
        database: str,
        catalog: str | None = None,
        result_bucket_name: str = "athena-query-results",
        environment: Environment | None = None,
    ) -> None:
        """Initialise the shared Athena settings.

        Parameters
        ----------
        database : str
            Database used as the query execution context.
        catalog : str, optional
            Catalog added to the execution context; omitted when None.
        result_bucket_name : str, optional
            Name fragment passed to ``Environment.get_full_bucket_url`` (with
            ``full_prefix=True``) to find the bucket that receives query results.
        environment : Environment, optional
            Environment used to resolve the result bucket and account number. When None,
            ``Environment(use_web_identity=False)`` is used; because Environment is a
            singleton, an already-created instance is returned as-is.

        Raises
        ------
        ValueError
            If no bucket matching ``result_bucket_name`` exists in the environment.

        """
        if environment is None:
            environment = Environment(use_web_identity=False)

        result_bucket = environment.get_full_bucket_url(result_bucket_name, full_prefix=True)
        if result_bucket is None:
            # Fail fast: formatting None would make the literal string "None" the
            # Athena output location, which only errors later inside Athena.
            msg = f"No S3 bucket found matching result bucket name {result_bucket_name!r}."
            raise ValueError(msg)

        self.result_bucket = str(result_bucket)
        self.athena_client = boto3.client("athena")
        self.database = database
        self.catalog = catalog
        self.workgroup = f"{environment.account_number}-ears-sars"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class AthenaQuery:
|
|
60
|
+
"""Abstraction of an Athena query.
|
|
61
|
+
|
|
62
|
+
boto3's API triggers the query, gets the query status and retrieves the response as separate calls.
|
|
63
|
+
AthenaQuery implements methods to simplify this flow.
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
query: Query
|
|
67
|
+
execution_id: str | None
|
|
68
|
+
config: AthenaConfig
|
|
69
|
+
has_succeeded: bool
|
|
70
|
+
|
|
71
|
+
def __init__(self, query: Query, config: AthenaConfig) -> None:
|
|
72
|
+
"""Initialise class."""
|
|
73
|
+
self.query = query
|
|
74
|
+
self.execution_id = None
|
|
75
|
+
self.config = config
|
|
76
|
+
self.has_succeeded = False
|
|
77
|
+
|
|
78
|
+
def run_and_await(self) -> None:
|
|
79
|
+
"""Run query, wait for it to complete."""
|
|
80
|
+
self.execution_id = self.start_execution()
|
|
81
|
+
self.await_query_completion()
|
|
82
|
+
|
|
83
|
+
def start_execution(self) -> str:
|
|
84
|
+
"""Begin an Athena query with standard settings. Return the query ID."""
|
|
85
|
+
query_context: QueryExecutionContextTypeDef = {"Database": self.config.database}
|
|
86
|
+
if self.config.catalog is not None:
|
|
87
|
+
query_context["Catalog"] = self.config.catalog
|
|
88
|
+
|
|
89
|
+
query_string = str(self.query).strip()
|
|
90
|
+
if not query_string:
|
|
91
|
+
msg = "Query string cannot be empty"
|
|
92
|
+
raise ValueError(msg)
|
|
93
|
+
|
|
94
|
+
logger.debug(query_string)
|
|
95
|
+
response = self.config.athena_client.start_query_execution(
|
|
96
|
+
QueryString=str(self.query),
|
|
97
|
+
ResultConfiguration={"OutputLocation": self.config.result_bucket},
|
|
98
|
+
QueryExecutionContext=query_context,
|
|
99
|
+
WorkGroup=self.config.workgroup,
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
self.execution_id = str(response["QueryExecutionId"])
|
|
103
|
+
if not self.execution_id:
|
|
104
|
+
msg = "No Query Execution Id. Response: %s"
|
|
105
|
+
raise NoQueryIdError(msg, response)
|
|
106
|
+
|
|
107
|
+
logger.info("Query Execution ID: %s", self.execution_id)
|
|
108
|
+
return self.execution_id
|
|
109
|
+
|
|
110
|
+
def stop_execution(self) -> None:
|
|
111
|
+
"""Stop an Athena query's execution."""
|
|
112
|
+
if self.execution_id is None:
|
|
113
|
+
msg = "No query id available. Make sure that the query has been startedbefore attempting to stop it."
|
|
114
|
+
raise AthenaQueryError(msg)
|
|
115
|
+
|
|
116
|
+
self.config.athena_client.stop_query_execution(QueryExecutionId=self.execution_id)
|
|
117
|
+
|
|
118
|
+
def await_query_completion(
|
|
119
|
+
self,
|
|
120
|
+
max_iterations: int = 20,
|
|
121
|
+
max_interval_seconds: int = 120,
|
|
122
|
+
) -> None:
|
|
123
|
+
"""Wait for a managed query with the given ID to finish and return the result."""
|
|
124
|
+
if self.execution_id is None:
|
|
125
|
+
msg = "No query id available. Make sure that the query has been startedbefore querying the result."
|
|
126
|
+
raise ValueError(msg)
|
|
127
|
+
|
|
128
|
+
for i in range(max_iterations):
|
|
129
|
+
# Exponential backoff on checking query result
|
|
130
|
+
wait_time = min(max_interval_seconds, 2**i)
|
|
131
|
+
time.sleep(wait_time)
|
|
132
|
+
|
|
133
|
+
# Check whether the query has terminated
|
|
134
|
+
response = self.config.athena_client.get_query_execution(
|
|
135
|
+
QueryExecutionId=self.execution_id,
|
|
136
|
+
)
|
|
137
|
+
execution_status = response["QueryExecution"]["Status"]
|
|
138
|
+
state = execution_status["State"]
|
|
139
|
+
logger.debug("Query state: %s", state)
|
|
140
|
+
|
|
141
|
+
match state:
|
|
142
|
+
case "QUEUED" | "RUNNING":
|
|
143
|
+
continue
|
|
144
|
+
case "SUCCEEDED":
|
|
145
|
+
self.has_succeeded = True
|
|
146
|
+
logger.info("Query succeeded")
|
|
147
|
+
return
|
|
148
|
+
case "FAILED" | "CANCELLED":
|
|
149
|
+
msg = f"Query execution failed: {execution_status}"
|
|
150
|
+
raise AthenaQueryError(msg)
|
|
151
|
+
case _:
|
|
152
|
+
msg = f"Unknown state: {state}. Full status: {execution_status}"
|
|
153
|
+
raise AthenaQueryError(msg)
|
|
154
|
+
|
|
155
|
+
# If we've polled as many times as we allow and haven't had a result, cancel
|
|
156
|
+
# the query and throw an error.
|
|
157
|
+
self.stop_execution()
|
|
158
|
+
msg = f"Maximum configured query duration reached. Stopped query: {self.execution_id}"
|
|
159
|
+
raise AthenaQueryError(msg)
|
|
160
|
+
|
|
161
|
+
def _get_responses(self) -> Generator[GetQueryResultsOutputTypeDef]:
|
|
162
|
+
"""Get the raw results of the query."""
|
|
163
|
+
# If we don't already know it to have finished, wait for the query to finish
|
|
164
|
+
if not self.has_succeeded:
|
|
165
|
+
self.await_query_completion()
|
|
166
|
+
|
|
167
|
+
if self.execution_id is None:
|
|
168
|
+
msg = "No query id available."
|
|
169
|
+
raise ValueError(msg)
|
|
170
|
+
|
|
171
|
+
response = self.config.athena_client.get_query_results(
|
|
172
|
+
QueryExecutionId=self.execution_id,
|
|
173
|
+
)
|
|
174
|
+
yield response
|
|
175
|
+
|
|
176
|
+
while "NextToken" in response:
|
|
177
|
+
response = self.config.athena_client.get_query_results(
|
|
178
|
+
QueryExecutionId=self.execution_id,
|
|
179
|
+
NextToken=response["NextToken"],
|
|
180
|
+
)
|
|
181
|
+
yield response
|
|
182
|
+
|
|
183
|
+
def parse_response_as_dataframe(self) -> pl.DataFrame:
|
|
184
|
+
"""Return the table of results from the query as a dict of columns."""
|
|
185
|
+
# Get the raw query results from the API
|
|
186
|
+
responses = self._get_responses()
|
|
187
|
+
|
|
188
|
+
# Parse the raw row data from the results
|
|
189
|
+
rows = []
|
|
190
|
+
for response in responses:
|
|
191
|
+
for row in response["ResultSet"]["Rows"]:
|
|
192
|
+
row_data = [column.get("VarCharValue") for column in row["Data"]]
|
|
193
|
+
rows.append(row_data)
|
|
194
|
+
|
|
195
|
+
if len(rows) == 0:
|
|
196
|
+
return pl.DataFrame({})
|
|
197
|
+
|
|
198
|
+
# The first row returned by the response is the names of the columns.
|
|
199
|
+
column_names, values_in_rows = rows[0], rows[1:]
|
|
200
|
+
|
|
201
|
+
if len(values_in_rows) == 0:
|
|
202
|
+
return pl.DataFrame(data={col: [] for col in column_names})
|
|
203
|
+
|
|
204
|
+
# Convert a list of rows into a list of columns
|
|
205
|
+
values_in_columns = list(zip(*values_in_rows, strict=True))
|
|
206
|
+
|
|
207
|
+
# Package it up into a dictionary of columns
|
|
208
|
+
return pl.DataFrame(dict(zip(column_names, values_in_columns, strict=True)))
|
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
"""Class to help with AWS environment inference and configuration."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import secrets
|
|
6
|
+
import string
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Self
|
|
9
|
+
|
|
10
|
+
import boto3
|
|
11
|
+
import botocore.exceptions
|
|
12
|
+
from cloudpathlib import S3Path
|
|
13
|
+
from mypy_boto3_s3.type_defs import BucketTypeDef
|
|
14
|
+
from mypy_boto3_sts.type_defs import CredentialsTypeDef
|
|
15
|
+
|
|
16
|
+
# NOTE(review): basicConfig at import time configures the *root* logger for any
# application that imports this module (it is a no-op only if the root logger
# already has handlers). Consider leaving logging configuration to the application.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the environment helpers below.
logger = logging.getLogger(__name__)

# Characters drawn by Environment.refresh_credentials for its random role-session-name suffix.
alphabet = string.ascii_lowercase
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _raise_missing_token() -> None:
    """Log and raise the error for an unset ``AWS_WEB_IDENTITY_TOKEN_FILE``.

    Raises
    ------
    RuntimeError
        Always.

    """
    msg = "AWS_WEB_IDENTITY_TOKEN_FILE environment variable not set."
    # logger.error rather than logger.exception: this runs from a try body, not an
    # except handler, so there is no active exception whose traceback could be logged.
    logger.error(msg)
    raise RuntimeError(msg)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _raise_missing_role_arn() -> None:
    """Signal that the ``AWS_ROLE_ARN`` environment variable is not set.

    Raises
    ------
    MissingEnvVarRoleArnError
        Always.

    """
    error = MissingEnvVarRoleArnError()
    raise error
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class MissingEnvVarRoleArnError(Exception):
    """Exception for a missing ``AWS_ROLE_ARN`` environment variable."""

    def __init__(self) -> None:
        """Create the error with a fixed message asking for AWS_ROLE_ARN to be set."""
        self.message = "Add AWS_ROLE_ARN to environment variables."
        # Pass the message through so it shows up in str(err) and in tracebacks;
        # with no arguments str(err) would be empty.
        super().__init__(self.message)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class Environment:
    """AWS environment inference and configuration.

    This class helps determine the environment (prod, preprod, test, dev),
    manage AWS credentials (via web identity or default),
    and construct environment-specific S3 bucket URLs.

    The class is a singleton: every ``Environment(...)`` call returns the same object,
    and the set-up in ``__init__`` runs only on the first call (or the first call after
    :meth:`clear`). Arguments passed to later calls are ignored.
    """

    _instance: Self | None = None

    def __new__(cls, *_args: object, **_kwargs: object) -> Self:
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self,
        job_name: str | None = "",
        bucket_prefix: str | None = "emds",
        *,
        use_web_identity: bool,
    ) -> None:
        """Initialize the environment context.

        Parameters
        ----------
        job_name : str, optional
            Name used in the role session name when assuming a role, by default empty string.
        bucket_prefix : str, optional
            Base prefix used to identify datahub buckets, by default "emds".
        use_web_identity : bool
            Keyword-only and required. If True, the session is built from credentials
            obtained with ``AssumeRoleWithWebIdentity``; failures are logged and re-raised.
            If False, a default boto3 session is used.

        Notes
        -----
        Set-up calls IAM (account aliases), STS (caller identity) and S3 (bucket list).
        It is skipped when this instance has already been initialised.

        """
        if getattr(self, "_initialized", False):
            return

        self.job_name = job_name
        self.bucket_prefix = bucket_prefix
        self.use_web_identity = use_web_identity

        self.session = self._init_session()
        self.alias = self._fetch_account_alias()
        self.account_no = self._get_account_number()
        self.bucket_list = self._list_buckets()

        self._initialized: bool = True

    def _assume_role_with_web_identity(self, role_session_name: str) -> CredentialsTypeDef:
        """Assume ``AWS_ROLE_ARN`` using the token stored in ``AWS_WEB_IDENTITY_TOKEN_FILE``.

        Shared by :meth:`_init_session` and :meth:`refresh_credentials`.

        Parameters
        ----------
        role_session_name : str
            ``RoleSessionName`` sent to STS.

        Returns
        -------
        CredentialsTypeDef
            Temporary credentials requested for 900 seconds.

        Raises
        ------
        RuntimeError
            If ``AWS_WEB_IDENTITY_TOKEN_FILE`` is not set.
        MissingEnvVarRoleArnError
            If ``AWS_ROLE_ARN`` is not set.

        """
        # The _raise_* helpers always raise, but they are annotated -> None, so the
        # else branches are what narrow the Optional values for type checkers.
        token_path = os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
        if token_path is None:
            _raise_missing_token()
        else:
            web_identity_token = Path(token_path).read_text()

        role_arn = os.environ.get("AWS_ROLE_ARN")
        if role_arn is None:
            _raise_missing_role_arn()
        else:
            role_arn_arg = role_arn

        response = boto3.client("sts").assume_role_with_web_identity(
            RoleArn=role_arn_arg,
            RoleSessionName=role_session_name,
            WebIdentityToken=web_identity_token,
            DurationSeconds=900,
        )
        return response["Credentials"]

    def _init_session(self) -> boto3.Session:
        """Initialise a boto3 session, optionally using web identity credentials.

        Returns
        -------
        boto3.Session
            A session built from web identity credentials when ``use_web_identity`` is
            True, otherwise a default session.

        Raises
        ------
        Exception
            Any error from the web identity exchange is logged and re-raised; there is
            no fallback to the default session.

        """
        if not self.use_web_identity:
            return boto3.session.Session()

        try:
            credentials = self._assume_role_with_web_identity(f"session-{self.job_name}")
        except Exception as e:
            logger.warning("Web identity failed: %s. Not falling back to the default session.", e)
            raise

        return boto3.session.Session(
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
        )

    def _fetch_account_alias(self) -> str:
        """Fetch the AWS account alias.

        Returns
        -------
        str
            Account alias or default.

        Notes
        -----
        Falls back to 'preproduction' alias if none found.

        """
        try:
            aliases = boto3.client("iam").list_account_aliases().get("AccountAliases", [])
            return aliases[0] if aliases else "electronic-monitoring-data-preproduction"
        except botocore.exceptions.ClientError:
            logger.warning("Failed to fetch account alias, assuming preproduction.")
            return "electronic-monitoring-data-preproduction"

    def _get_account_number(self) -> str:
        """Return the AWS account number.

        Raises
        ------
        RuntimeError
            If no AWS credentials are available.

        """
        try:
            return boto3.client("sts").get_caller_identity()["Account"]
        except botocore.exceptions.NoCredentialsError:
            msg = "AWS credentials not found."
            logger.exception(msg)
            raise RuntimeError(msg) from None

    def _list_buckets(self) -> list[BucketTypeDef]:
        """List all available S3 buckets; errors are logged and re-raised."""
        try:
            return boto3.client("s3").list_buckets()["Buckets"]
        except Exception as e:
            logger.warning("Could not list buckets: %s", e)
            raise

    @property
    def account_number(self) -> str:
        """Return the AWS account number."""
        return self.account_no

    @property
    def environment_name(self) -> str:
        """Infer environment name from account alias.

        Returns
        -------
        str
            One of: prod, dev, preprod, test, or fallback to raw alias suffix.

        """
        full_env_name = self.alias.split("-")[-1]
        return {
            "production": "prod",
            "development": "dev",
            "preproduction": "preprod",
            "test": "test",
        }.get(full_env_name, full_env_name)

    @property
    def is_prod(self) -> bool:
        """Check if the environment is production."""
        return self.environment_name == "prod"

    def get_full_bucket_url(
        self,
        bucket_prefix: str | None = None,
        *,
        full_prefix: bool,
    ) -> S3Path | None:
        """Get S3Path to bucket matching environment and prefix.

        The expected name is ``"{self.bucket_prefix}-{environment_name}-{bucket_prefix}"``,
        where ``bucket_prefix`` falls back to ``self.bucket_prefix`` when not given.
        Buckets are checked in the order S3 listed them; the first match wins.

        Parameters
        ----------
        bucket_prefix : str, optional
            Bucket-specific part of the name to search for.
        full_prefix : bool
            Keyword-only and required. If True, a bucket matches when its name equals the
            expected name either exactly or after dropping its last ``-``-separated part
            (such as a timestamp suffix). If False, any bucket whose name contains the
            expected name matches.

        Returns
        -------
        Optional[S3Path]
            ``s3://<bucket name>`` for the first matching bucket, or None if none matches.

        """
        search_prefix = bucket_prefix or self.bucket_prefix
        expected_name = f"{self.bucket_prefix}-{self.environment_name}-{search_prefix}"

        for bucket in self.bucket_list:
            bucket_name = bucket["Name"]
            if full_prefix:
                name_without_suffix = "-".join(bucket_name.split("-")[:-1])
                matched = expected_name in (bucket_name, name_without_suffix)
            else:
                matched = expected_name in bucket_name
            if matched:
                return S3Path(f"s3://{bucket_name}")
        return None

    def get_api_invoke_url(self, api_name: str, region: str) -> str:
        """Return the base invoke URL of the API Gateway REST API named ``api_name``.

        All pages of ``GetRestApis`` are searched, not just the first one.

        Parameters
        ----------
        api_name : str
            Exact name of the REST API.
        region : str
            AWS region to search, also used in the returned URL.

        Returns
        -------
        str
            ``https://<api id>.execute-api.<region>.amazonaws.com/`` (no stage path).

        Raises
        ------
        ValueError
            If no REST API, or more than one, has that name.

        """
        client = boto3.client("apigateway", region)
        pages = client.get_paginator("get_rest_apis").paginate()
        matches = [api for page in pages for api in page.get("items", []) if api.get("name") == api_name]
        if not matches:
            msg = f"No REST API named {api_name!r} found in {region}."
            raise ValueError(msg)
        if len(matches) > 1:
            msg = f"More than one REST API named {api_name!r} found in {region}."
            raise ValueError(msg)
        return f"https://{matches[0]['id']}.execute-api.{region}.amazonaws.com/"

    def refresh_credentials(self) -> CredentialsTypeDef:
        """Refresh credentials via STS.

        Tries ``AssumeRoleWithWebIdentity`` first, with a session name ending in ten
        random lowercase letters. If that fails for any reason, the error is logged and
        ``GetSessionToken`` is used instead. Both request 900-second credentials.

        Returns
        -------
        dict
            New credentials.

        """
        rand_suffix = "".join(secrets.choice(alphabet) for _ in range(10))
        try:
            return self._assume_role_with_web_identity(f"session-{self.job_name}-{rand_suffix}")
        except Exception:
            logger.exception("Web identity failed. Falling back to get_session_token.")

        response_session_token = boto3.client("sts").get_session_token(DurationSeconds=900)
        return response_session_token["Credentials"]

    def export_dbt_variables(self, *, actions: bool = False, airflow: bool = False) -> None:
        """Write this environment's dbt variables to ``set_env.sh`` in the working directory.

        The file is overwritten if it exists. The variables are written in this order:

        - ``DBT_SUFFIX``
        - ``S3_DATA_BUCKET_NAME``
        - ``DBT_TEST_PROFILE_WORKGROUP``
        - ``DBT_PROFILES_DIR``, only when neither ``actions`` nor ``airflow`` is set
        - ``H3_LAMBDA_ARN``

        Parameters
        ----------
        actions : bool, optional
            If True, write ``echo "NAME=value" >> $GITHUB_ENV`` lines for GitHub Actions;
            otherwise write ``export NAME='value'`` lines to be sourced by a shell.
        airflow : bool, optional
            If True, do not write ``DBT_PROFILES_DIR``.

        """
        s3_data_bucket_name = self.get_full_bucket_url("cadt", full_prefix=True)
        if s3_data_bucket_name is None:
            logger.warning("No 'cadt' data bucket found; S3_DATA_BUCKET_NAME will be 'None'.")
        dbt_test_profile_workgroup = f"{self.account_number}-default"
        dbt_suffix = "" if self.is_prod else f"_{self.environment_name}_dbt"
        h3_lambda_arn = f"arn:aws:lambda:eu-west-2:{self.account_no}:function:h3-udf"

        variables: list[tuple[str, str]] = [
            ("DBT_SUFFIX", dbt_suffix),
            ("S3_DATA_BUCKET_NAME", str(s3_data_bucket_name)),
            ("DBT_TEST_PROFILE_WORKGROUP", dbt_test_profile_workgroup),
        ]
        if not actions and not airflow:
            variables.append(("DBT_PROFILES_DIR", "../.dbt/"))
        variables.append(("H3_LAMBDA_ARN", h3_lambda_arn))

        # One template per target, so every variable becomes a single complete line
        # ending in a newline, in the syntax that target expects.
        if actions:
            lines = [f'echo "{name}={value}" >> $GITHUB_ENV\n' for name, value in variables]
        else:
            lines = [f"export {name}='{value}'\n" for name, value in variables]

        with Path("set_env.sh").open("w") as f:
            f.writelines(lines)

    @classmethod
    def clear(cls) -> None:
        """Reset the singleton instance.

        Use this to force restart of the class. Mainly for testing.
        """
        cls._instance = None

    @classmethod
    def instance(cls) -> "Environment | None":
        """Return the current singleton instance, if any. Mainly for testing."""
        return cls._instance
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
"""Build safe trino query."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
import textwrap
|
|
5
|
+
from datetime import UTC, date, datetime
|
|
6
|
+
from typing import ClassVar, Literal, Self
|
|
7
|
+
|
|
8
|
+
IDENTIFIER = re.compile(r"^[a-z_][a-z0-9_\$]*$")
|
|
9
|
+
DISALLOWED_CHARACTERS = re.compile(r"[;\n\r]")
|
|
10
|
+
MAX_LIST_LENGTH = 1000
|
|
11
|
+
|
|
12
|
+
Column = str
|
|
13
|
+
OrderByClause = tuple[Column, Literal["ASC", "DESC"]]
|
|
14
|
+
DatesConditions = tuple[Literal["<", ">", "=", "<=", ">="], datetime | str]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _ensure_safe_fragment(fragment: str) -> None:
    """Reject a fragment containing any character matched by ``DISALLOWED_CHARACTERS``.

    Raises
    ------
    ValueError
        If a disallowed character is present.

    """
    if DISALLOWED_CHARACTERS.search(fragment) is None:
        return
    msg = "Disallowed characters in fragment"
    raise ValueError(msg)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def validate_identifier(name: str) -> str:
    """Return ``name`` unchanged if it is a valid SQL/Athena identifier.

    The whole string must match ``IDENTIFIER``.

    Raises
    ------
    ValueError
        If ``name`` is not a valid identifier.

    """
    # fullmatch rather than match: with match, "$" also matches just before a
    # trailing newline, so a value like "col\n" would be accepted.
    if IDENTIFIER.fullmatch(name) is None:
        msg = f"Invalid identifier: {name!r}"
        raise ValueError(msg)
    return name
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def validate_order_by_direction(direction: str) -> None:
    """Raise ``ValueError`` unless ``direction`` is ``ASC`` or ``DESC``, ignoring case."""
    normalised = direction.upper()
    if normalised in {"ASC", "DESC"}:
        return
    msg = f"Invalid direction in order by clause: {direction}"
    raise ValueError(msg)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def quote_literal(value: None | str | float | datetime | date) -> str:
|
|
40
|
+
"""Safely quote literal values for SQL."""
|
|
41
|
+
if isinstance(value, bool):
|
|
42
|
+
return "TRUE" if value else "FALSE"
|
|
43
|
+
if value is None:
|
|
44
|
+
return "NULL"
|
|
45
|
+
if isinstance(value, (int, float)):
|
|
46
|
+
return str(value)
|
|
47
|
+
if isinstance(value, str):
|
|
48
|
+
_ensure_safe_fragment(value)
|
|
49
|
+
escaped = value.replace("'", "''")
|
|
50
|
+
return f"'{escaped}'"
|
|
51
|
+
if isinstance(value, (date, datetime)):
|
|
52
|
+
return f"cast('{value.strftime('%Y-%m-%d %H:%M:%S')}' as timestamp(6))"
|
|
53
|
+
return None
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def quote_list_literal(values: list[str | int | float] | tuple[str | int | float]) -> str:
|
|
57
|
+
"""Quote a list into trino."""
|
|
58
|
+
vals = list(values)
|
|
59
|
+
if len(vals) > MAX_LIST_LENGTH:
|
|
60
|
+
msg = "Too many items in list"
|
|
61
|
+
raise ValueError(msg)
|
|
62
|
+
return "(" + ", ".join(quote_literal(v) for v in vals) + ")"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class Query:
|
|
66
|
+
"""A safe Athena SQL query builder."""
|
|
67
|
+
|
|
68
|
+
keywords: ClassVar[list[str]] = [
|
|
69
|
+
"WITH",
|
|
70
|
+
"SELECT",
|
|
71
|
+
"FROM",
|
|
72
|
+
"WHERE",
|
|
73
|
+
"GROUP BY",
|
|
74
|
+
"HAVING",
|
|
75
|
+
"ORDER BY",
|
|
76
|
+
"LIMIT",
|
|
77
|
+
]
|
|
78
|
+
|
|
79
|
+
def __init__(self) -> None:
|
|
80
|
+
"""Init function."""
|
|
81
|
+
self.parts: dict[str, list[str]] = {kw: [] for kw in self.keywords}
|
|
82
|
+
|
|
83
|
+
def SELECT(self, *columns: str, distinct: bool = False) -> Self: # noqa: N802
|
|
84
|
+
"""Wrap selected columns."""
|
|
85
|
+
for c in columns:
|
|
86
|
+
if c != "*":
|
|
87
|
+
for part in c.split(","):
|
|
88
|
+
validate_identifier(part)
|
|
89
|
+
col_list = [f'"{c}"' if c != "*" else "*" for c in columns]
|
|
90
|
+
if distinct:
|
|
91
|
+
col_list[0] = "DISTINCT " + col_list[0]
|
|
92
|
+
self.parts["SELECT"].extend(col_list)
|
|
93
|
+
return self
|
|
94
|
+
|
|
95
|
+
def FROM(self, database: str, table: str) -> Self: # noqa: N802
|
|
96
|
+
"""From database name table name."""
|
|
97
|
+
validate_identifier(database)
|
|
98
|
+
validate_identifier(table)
|
|
99
|
+
self.parts["FROM"].append(f'"{database}"."{table}"')
|
|
100
|
+
return self
|
|
101
|
+
|
|
102
|
+
def WHERE( # noqa: N802
|
|
103
|
+
self, *, unquote: bool = False, **conditions: str | list[str | int | float] | tuple[str | int | float]
|
|
104
|
+
) -> Self:
|
|
105
|
+
"""Where col=value style."""
|
|
106
|
+
for col, val in conditions.items():
|
|
107
|
+
validate_identifier(col)
|
|
108
|
+
if isinstance(val, (list, tuple)):
|
|
109
|
+
expr = f'"{col}" IN {val}' if unquote else f'"{col}" IN {quote_list_literal(val)}'
|
|
110
|
+
else:
|
|
111
|
+
expr = f'"{col}" = {val}' if unquote else f'"{col}" = {quote_literal(val)}'
|
|
112
|
+
self.parts["WHERE"].append(expr)
|
|
113
|
+
return self
|
|
114
|
+
|
|
115
|
+
def WHERE_LIKE( # noqa: N802
|
|
116
|
+
self,
|
|
117
|
+
field_wrapper: Literal["", "upper", "lower"] = "",
|
|
118
|
+
connector: Literal["", "OR", "AND"] = "",
|
|
119
|
+
**conditions: str,
|
|
120
|
+
) -> Self:
|
|
121
|
+
"""Where Like filter."""
|
|
122
|
+
for col, val in conditions.items():
|
|
123
|
+
validate_identifier(col)
|
|
124
|
+
if isinstance(val, list):
|
|
125
|
+
range_exprs = [
|
|
126
|
+
f'{field_wrapper}("{col}") LIKE {field_wrapper}({quote_literal("%" + str(option) + "%")})'
|
|
127
|
+
for option in val
|
|
128
|
+
]
|
|
129
|
+
expr = "(" + f" {connector} ".join(range_exprs) + ")"
|
|
130
|
+
else:
|
|
131
|
+
expr = f'{field_wrapper}("{col}") LIKE {field_wrapper}({quote_literal("%" + str(val) + "%")})'
|
|
132
|
+
self.parts["WHERE"].append(expr)
|
|
133
|
+
return self
|
|
134
|
+
|
|
135
|
+
def DATES(self, **conditions: DatesConditions) -> Self: # noqa: N802
|
|
136
|
+
"""Add specific date filtering
|
|
137
|
+
Accepts conditions where the key is the column name and value is a tuple: (operator, date_string)
|
|
138
|
+
Intended to be used to constructing a bounding range of dates rather than dates between two.
|
|
139
|
+
For dates between two values, see DATE_RANGES.
|
|
140
|
+
|
|
141
|
+
Example:
|
|
142
|
+
-------
|
|
143
|
+
- QUERY().DATES(field_1=('<', 2023-01-01))
|
|
144
|
+
|
|
145
|
+
""" # noqa: D205
|
|
146
|
+
valid_operators = {"<", ">", "=", "<=", ">="}
|
|
147
|
+
expressions = []
|
|
148
|
+
for col, (operator, date_lit) in conditions.items():
|
|
149
|
+
# Validation
|
|
150
|
+
if operator not in valid_operators:
|
|
151
|
+
err_txt = f"Invalid operator: {operator}"
|
|
152
|
+
raise ValueError(err_txt)
|
|
153
|
+
validate_identifier(col)
|
|
154
|
+
|
|
155
|
+
if isinstance(date_lit, str):
|
|
156
|
+
dt = datetime.strptime(date_lit, "%Y-%m-%d %H:%M:%S").astimezone(UTC)
|
|
157
|
+
dt.replace(tzinfo=UTC)
|
|
158
|
+
expr = f'"{col}" {operator} {quote_literal(dt)}'
|
|
159
|
+
elif isinstance(date_lit, datetime):
|
|
160
|
+
date_lit.replace(tzinfo=UTC)
|
|
161
|
+
expr = f'"{col}" {operator} {quote_literal(date_lit)}'
|
|
162
|
+
expressions.append(expr)
|
|
163
|
+
group_expr = "(" + " AND ".join(expressions) + ")"
|
|
164
|
+
|
|
165
|
+
if not expressions:
|
|
166
|
+
self.parts["WHERE"].append("")
|
|
167
|
+
else:
|
|
168
|
+
self.parts["WHERE"].append(group_expr)
|
|
169
|
+
return self
|
|
170
|
+
|
|
171
|
+
def DATE_RANGES(self, cols: list[str], start: datetime, end: datetime) -> Self: # noqa: N802
|
|
172
|
+
"""Add a grouped OR condition for date ranges: (col1 BETWEEN ... OR col2 BETWEEN ...)."""
|
|
173
|
+
for c in cols:
|
|
174
|
+
validate_identifier(c)
|
|
175
|
+
|
|
176
|
+
start_lit = quote_literal(start)
|
|
177
|
+
end_lit = quote_literal(end)
|
|
178
|
+
|
|
179
|
+
range_exprs = [f'"{c}" BETWEEN {start_lit} AND {end_lit}' for c in cols]
|
|
180
|
+
group_expr = "(" + " OR ".join(range_exprs) + ")"
|
|
181
|
+
self.parts["WHERE"].append(group_expr)
|
|
182
|
+
return self
|
|
183
|
+
|
|
184
|
+
def ORDER_BY(self, *order_by_clauses: str | OrderByClause) -> Self: # noqa: N802
|
|
185
|
+
"""Order by action."""
|
|
186
|
+
for clause in order_by_clauses:
|
|
187
|
+
if isinstance(clause, tuple):
|
|
188
|
+
col, direction = clause
|
|
189
|
+
for part in col.split("."):
|
|
190
|
+
validate_identifier(part)
|
|
191
|
+
validate_order_by_direction(direction)
|
|
192
|
+
self.parts["ORDER BY"].append(f'"{col}" {direction}')
|
|
193
|
+
else:
|
|
194
|
+
# Allow ordering by columns without direction
|
|
195
|
+
col = clause
|
|
196
|
+
for part in col.split("."):
|
|
197
|
+
validate_identifier(part)
|
|
198
|
+
self.parts["ORDER BY"].append(f'"{col}"')
|
|
199
|
+
|
|
200
|
+
return self
|
|
201
|
+
|
|
202
|
+
def LIMIT(self, n: int) -> Self: # noqa: N802
|
|
203
|
+
"""Limit query."""
|
|
204
|
+
if not isinstance(n, int) or n < 0:
|
|
205
|
+
msg = "LIMIT must be a non-negative integer"
|
|
206
|
+
raise ValueError(msg)
|
|
207
|
+
self.parts["LIMIT"].append(str(n))
|
|
208
|
+
return self
|
|
209
|
+
|
|
210
|
+
def __str__(self) -> str:
|
|
211
|
+
"""Create query."""
|
|
212
|
+
sql_lines = []
|
|
213
|
+
for kw in self.keywords:
|
|
214
|
+
vals = self.parts[kw]
|
|
215
|
+
if not vals:
|
|
216
|
+
continue
|
|
217
|
+
match kw:
|
|
218
|
+
case "SELECT" | "ORDER BY" | "GROUP BY":
|
|
219
|
+
sql_lines.append(f"{kw} " + ", ".join(vals))
|
|
220
|
+
case "FROM":
|
|
221
|
+
sql_lines.append(f"{kw} " + ", ".join(vals))
|
|
222
|
+
case "WHERE" | "HAVING":
|
|
223
|
+
sql_lines.append(f"{kw} " + " AND ".join(vals))
|
|
224
|
+
case "LIMIT":
|
|
225
|
+
sql_lines.append(f"{kw} {vals[0]}")
|
|
226
|
+
case _:
|
|
227
|
+
sql_lines.append(f"{kw} {' '.join(vals)}")
|
|
228
|
+
return "\n".join(sql_lines)
|
|
229
|
+
|
|
230
|
+
def pretty(self) -> str:
|
|
231
|
+
"""Make SQL look nice."""
|
|
232
|
+
return textwrap.indent(str(self), " ")
|