peak-sdk 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- peak/__init__.py +36 -0
- peak/_version.py +21 -0
- peak/auth.py +22 -0
- peak/base_client.py +52 -0
- peak/cli/__init__.py +20 -0
- peak/cli/args.py +84 -0
- peak/cli/cli.py +56 -0
- peak/cli/helpers.py +187 -0
- peak/cli/press/__init__.py +21 -0
- peak/cli/press/apps/__init__.py +40 -0
- peak/cli/press/apps/deployments.py +238 -0
- peak/cli/press/apps/specs.py +387 -0
- peak/cli/press/blocks/__init__.py +40 -0
- peak/cli/press/blocks/deployments.py +240 -0
- peak/cli/press/blocks/specs.py +492 -0
- peak/cli/press/deployments.py +78 -0
- peak/cli/press/specs.py +131 -0
- peak/cli/resources/__init__.py +21 -0
- peak/cli/resources/artifacts.py +310 -0
- peak/cli/resources/images.py +886 -0
- peak/cli/resources/webapps.py +356 -0
- peak/cli/resources/workflows.py +703 -0
- peak/cli/ruff.toml +11 -0
- peak/cli/version.py +49 -0
- peak/compression.py +162 -0
- peak/config.py +24 -0
- peak/constants.py +105 -0
- peak/exceptions.py +217 -0
- peak/handler.py +358 -0
- peak/helpers.py +184 -0
- peak/logger.py +48 -0
- peak/press/__init__.py +28 -0
- peak/press/apps.py +669 -0
- peak/press/blocks.py +707 -0
- peak/press/deployments.py +145 -0
- peak/press/specs.py +260 -0
- peak/py.typed +0 -0
- peak/resources/__init__.py +28 -0
- peak/resources/artifacts.py +343 -0
- peak/resources/images.py +675 -0
- peak/resources/webapps.py +278 -0
- peak/resources/workflows.py +625 -0
- peak/session.py +259 -0
- peak/telemetry.py +201 -0
- peak/template.py +231 -0
- peak/validators.py +48 -0
- peak_sdk-1.0.0.dist-info/LICENSE +201 -0
- peak_sdk-1.0.0.dist-info/METADATA +199 -0
- peak_sdk-1.0.0.dist-info/RECORD +51 -0
- peak_sdk-1.0.0.dist-info/WHEEL +4 -0
- peak_sdk-1.0.0.dist-info/entry_points.txt +3 -0
peak/session.py
ADDED
@@ -0,0 +1,259 @@
#
# # Copyright © 2023 Peak AI Limited. or its affiliates. All Rights Reserved.
# #
# # Licensed under the Apache License, Version 2.0 (the "License"). You
# # may not use this file except in compliance with the License. A copy of
# # the License is located at:
# #
# # https://github.com/PeakBI/peak-sdk/blob/main/LICENSE
# #
# # or in the "license" file accompanying this file. This file is
# # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# # ANY KIND, either express or implied. See the License for the specific
# # language governing permissions and limitations under the License.
# #
# # This file is part of the peak-sdk.
# # see (https://github.com/PeakBI/peak-sdk)
# #
# # You should have received a copy of the APACHE LICENSE, VERSION 2.0
# # along with this program. If not, see <https://apache.org/licenses/LICENSE-2.0>
#
"""Session module for Peak API."""
from __future__ import annotations

import os
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional

from peak import exceptions
from peak.constants import DOWNLOAD_CHUNK_SIZE, ContentType, HttpMethods, Stage
from peak.handler import Handler
from peak.helpers import get_base_domain
from peak.logger import logger

DEFAULT_SESSION = None


def _get_default_session() -> Session:
    """Get the global default session object. Creates one if not already created and re-uses it.

    Returns:
        Session: The default session object
    """
    global DEFAULT_SESSION  # noqa: PLW0603
    if DEFAULT_SESSION is None:
        logger.debug("Creating DEFAULT_SESSION object")
        DEFAULT_SESSION = Session()

    logger.debug("DEFAULT_SESSION already present, reusing the object")
    return DEFAULT_SESSION


class Session:
    """A session stores credentials which are used to authenticate the requests.

    By default, a DEFAULT_SESSION is created which reads the credentials from the env variables.
    Custom Session objects can be created and used if you want to work with multiple tenants.
    """

    auth_token: str
    stage: Stage
    base_domain: str
    handler: Handler

    def __init__(
        self,
        auth_token: Optional[str] = None,
        stage: Optional[str] = None,
    ) -> None:
        """Initialize a session for the Peak API.

        Args:
            auth_token (str | None): Authentication token. Both API Key and Bearer tokens are supported.
                Picks up from `API_KEY` environment variable if not provided.
            stage (str | None): Name of the stage where tenant is created. Default is `prod`.
        """
        self.base_domain: str = ""
        self._set_auth_token(auth_token)
        self._set_stage(stage)
        self._set_base_domain()
        self.handler = Handler()

    def create_request(
        self,
        endpoint: str,
        method: HttpMethods,
        content_type: ContentType,
        *,
        params: Optional[Dict[str, Any]] = None,
        body: Optional[Dict[str, Any]] = None,
        path: Optional[str] = None,
        ignore_files: Optional[list[str]] = None,
        subdomain: Optional[str] = "service",
    ) -> Any:
        """Prepares a request to be sent over the network.

        Adds auth_token to the headers and creates URL using STAGE.
        To be used with endpoints which returns JSON parsable response.

        Args:
            endpoint (str): The endpoint to send the request to.
            method (HttpMethods): The HTTP method to use.
            content_type (ContentType): The content type of the request.
            params (Dict[str, Any], optional): params to send to the request, defaults to None
            body (Dict[str, Any], optional): body to send to the request, defaults to None
            path (Optional[str] optional): path to the file or folder that will be compressed and used as artifact, required for multipart requests.
            ignore_files(Optional[list[str]]): Ignore files to be used when creating artifact, used only for multipart requests.
            subdomain (Optional[str]): Subdomain for the endpoint. Defaults to `service`.

        Returns:
            Any: response dict object.
        """
        headers: Dict[str, str] = {"Authorization": self.auth_token}
        base_domain: str = get_base_domain(stage=self.stage.value, subdomain=subdomain)
        url: str = f"{base_domain}/{endpoint}"
        return self.handler.make_request(
            url,
            method,
            content_type=content_type,
            headers=headers,
            params=params or {},
            body=body or {},
            path=path,
            ignore_files=ignore_files,
            session_meta={
                "stage": self.stage,
            },
        ).json()

    def create_generator_request(
        self,
        endpoint: str,
        method: HttpMethods,
        content_type: ContentType,
        response_key: str,
        *,
        params: Optional[Dict[str, Any]] = None,
        body: Optional[Dict[str, Any]] = None,
        path: Optional[str] = None,
        subdomain: Optional[str] = "service",
    ) -> Iterator[Dict[str, Any]]:
        """Prepares a request to be sent over the network.

        Adds auth_token to the headers and creates URL using STAGE.
        Returns an iterator which automatically handles pagination and returns a new page at each iteration.
        To be used with list endpoints only, which returns `pageNumber`, `pageCount` keys in response.

        # noqa: DAR201

        Args:
            endpoint (str): The endpoint to send the request to.
            method (HttpMethods): The HTTP method to use.
            content_type (ContentType): The content type of the request.
            response_key (str): key in the response dict which contains actual list data.
            params (Optional[Dict[str, Any]]): params to send to the request.
            body (Optional[Dict[str, Any]]): body to send to the request.
            path (Optional[str]): path to the file or folder that will be compressed and used as artifact.
            subdomain (Optional[str]): Subdomain for the endpoint. Defaults to `service`.

        Yields:
            Iterator[Dict[str, Any]]: paginated response json, element wise.

        Raises:
            StopIteration: There are no more pages to list
        """
        page_number: int = 1
        page_count: int = 1
        params = params or {}
        while page_number <= page_count:
            params = {**params, "pageNumber": page_number}
            response = self.create_request(
                endpoint,
                method,
                content_type,
                params=params,
                body=body,
                path=path,
                subdomain=subdomain,
            )
            page_count = response["pageCount"]
            yield from response[response_key]
            page_number += 1
        return f"No more {response_key} to list"

    def create_download_request(
        self,
        endpoint: str,
        method: HttpMethods,
        content_type: ContentType,
        download_path: str,
        params: Optional[Dict[str, Any]] = None,
        body: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Prepares a request to be sent over the network.

        Adds auth_token to the headers and creates URL using STAGE.
        To be used for file download requests.

        Args:
            endpoint (str): The endpoint to send the request to.
            method (HttpMethods): The HTTP method to use.
            content_type (ContentType): The content type of the request.
            download_path (str): Path where the downloaded file will be stored.
            params (Dict[str, Any], optional): params to send to the request, defaults to None
            body (Dict[str, Any], optional): body to send to the request, only used in multipart requests, defaults to None

        Raises:
            InvalidPathException: The download_path is invalid.
        """
        headers: Dict[str, str] = {"Authorization": self.auth_token}
        url: str = f"{self.base_domain}/{endpoint}"
        response: Any = self.handler.make_request(
            url,
            method,
            content_type=content_type,
            headers=headers,
            params=params or {},
            body=body or {},
            request_kwargs={
                "stream": True,
                "allow_redirects": True,
            },
        )
        try:
            with Path(download_path).open("wb") as fd:
                for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                    fd.write(chunk)
        except IsADirectoryError:
            raise exceptions.InvalidPathException(
                download_path,
                "Path should include the name with which the downloaded file should be stored.",
            ) from None

    def _set_base_domain(self) -> None:
        self.base_domain = get_base_domain(stage=self.stage.value)

    def _set_auth_token(self, auth_token: Optional[str]) -> None:
        if auth_token is not None:
            self.auth_token = auth_token
            return

        logger.info("auth_token not given, searching for API_KEY in env variables")
        if not os.environ.get("API_KEY"):
            raise exceptions.MissingEnvironmentVariableException(env_var="API_KEY")
        self.auth_token = os.environ["API_KEY"]

    def _set_stage(self, stage: Optional[str]) -> None:
        if stage is not None:
            self.stage = Stage(stage)
            return

        logger.info("stage not given, searching for STAGE in env variables")
        if not os.environ.get("STAGE"):
            logger.info("STAGE environment variable is not set, defaulting to PROD")
            self.stage = Stage.PROD
            return
        self.stage = Stage(os.environ["STAGE"])


__all__: List[str] = ["Session", "_get_default_session"]
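Editor's note: for orientation only, a minimal usage sketch of the Session class above. It is not part of the wheel contents; the endpoint path, token, and params below are hypothetical placeholders.

# Illustrative sketch only -- not part of peak-sdk 1.0.0; endpoint and token values are placeholders.
from peak.constants import ContentType, HttpMethods
from peak.session import Session

# Falls back to the API_KEY and STAGE environment variables when the arguments are omitted.
session = Session(auth_token="<api-key>", stage="prod")

# create_request builds the URL from the stage and subdomain and returns the parsed JSON body.
response = session.create_request(
    "some-service/api/v1/resources",  # hypothetical endpoint path
    HttpMethods.GET,
    ContentType.APPLICATION_JSON,
    params={"pageSize": 25},
)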
peak/telemetry.py
ADDED
@@ -0,0 +1,201 @@
#
# # Copyright © 2023 Peak AI Limited. or its affiliates. All Rights Reserved.
# #
# # Licensed under the Apache License, Version 2.0 (the "License"). You
# # may not use this file except in compliance with the License. A copy of
# # the License is located at:
# #
# # https://github.com/PeakBI/peak-sdk/blob/main/LICENSE
# #
# # or in the "license" file accompanying this file. This file is
# # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# # ANY KIND, either express or implied. See the License for the specific
# # language governing permissions and limitations under the License.
# #
# # This file is part of the peak-sdk.
# # see (https://github.com/PeakBI/peak-sdk)
# #
# # You should have received a copy of the APACHE LICENSE, VERSION 2.0
# # along with this program. If not, see <https://apache.org/licenses/LICENSE-2.0>
#
"""Decorator for sending telemetry data for each request."""
from __future__ import annotations

import platform
import threading
import uuid
from contextlib import suppress
from functools import wraps
from typing import Any, Callable, Dict, Optional

import requests

import peak.config
from peak.constants import ContentType, HttpMethods, Stage
from peak.exceptions import BaseHttpException
from peak.helpers import get_base_domain

from ._version import __version__

F = Callable[..., requests.Response]
session_id = str(uuid.uuid4())


def get_status_code(error: Optional[Exception]) -> int | None:
    """It takes an exception object and returns the status code associated with it.

    Args:
        error (Optional[Exception]): The exception object to check

    Returns:
        int | None: The status code related to the error or None if no status code found
    """
    if not error:
        return 200

    if isinstance(error, BaseHttpException):
        return error.STATUS_CODE

    return None


def telemetry(make_request: F) -> F:
    """A decorator that wraps over the make_request function to send telemetry requests as required.

    Args:
        make_request (F): The make_request function to wrap in the decorator

    Returns:
        F: the wrapped function that sends telemetry data for each request
    """

    def get_telemetry_url(session_meta: Optional[Dict[str, Any]] = None) -> str:
        """Returns the telemetry url for the given stage.

        Args:
            session_meta (Optional[Dict[str, Any]]): Session metadata object that contains information like stage

        Returns:
            str: The telemetry URL
        """
        stage = Stage.PROD
        if session_meta:
            stage = session_meta["stage"] if "stage" in session_meta else Stage.PROD
        base_domain = get_base_domain(stage.value, "service")
        return f"{base_domain}/resource-usage/api/v1/telemetry"

    def get_telemetry_data() -> Dict[str, Any]:
        return {
            "sdkVersion": __version__,
            "os": platform.platform(),
            "hostname": platform.uname().node,
            "pythonVersion": platform.python_version(),
            "sessionId": session_id,
            "requestId": str(uuid.uuid4()),
        }

    @wraps(make_request)
    def wrapper(
        self: Any,
        url: str,
        method: HttpMethods,
        content_type: ContentType,
        headers: Optional[Dict[str, str]] = None,
        params: Optional[Dict[str, Any]] = None,
        body: Optional[Dict[str, Any]] = None,
        path: Optional[str] = None,
        request_kwargs: Optional[Dict[str, int | bool | str | float]] = None,
        ignore_files: Optional[list[str]] = None,
        session_meta: Optional[Dict[str, Any]] = None,
    ) -> requests.Response:
        """A decorator that wraps over the make_request function to send telemetry requests as required.

        Args:
            self (Any): the object instance of Handler class on which the make_request call is being made
            url (str): url to send the request to
            method (HttpMethods): The HTTP method to use, e.g. get, post, put, delete
            content_type (ContentType): content type of the request
            headers (Dict[str, str]): headers to send with the request
            params (Dict[str, Any]): params to send to the request
            body (Dict[str, Any]): body to send to the request
            path (str): path to the file or folder that will be compressed and used as artifact, defaults to None
            request_kwargs(Dict[str, int | bool | str | float] | None): extra arguments to be passed when making the request.
            ignore_files(Optional[list[str]]): Ignore files to be used when creating artifact
            session_meta(Dict[str, Any]): Metadata about the session object, like - stage

        Returns:
            requests.Response: response json

        Raises:
            BaseHttpException: The http request failed.
            Exception: Some other error occurred.
        """

        def make_telemetry_request(
            res: Optional[requests.Response] = None,
            error: Optional[Exception] = None,
        ) -> None:
            telemetry_url = get_telemetry_url(session_meta)

            telemetry_body = {
                "response": res.json() if method != HttpMethods.GET and res else None,
                "error": str(error) if error else None,
                "url": url,
                "requestMethod": method.value,
                "statusCode": get_status_code(error),
                "source": peak.config.SOURCE.value,
                **get_telemetry_data(),
            }

            with suppress(Exception):
                make_request(
                    self,
                    telemetry_url,
                    HttpMethods.POST,
                    ContentType.APPLICATION_JSON,
                    headers=headers,
                    body=telemetry_body,
                )

        def trigger_usage_collection(
            res: Optional[Any] = None,
            error: Optional[Exception] = None,
        ) -> None:
            thr = threading.Thread(
                target=make_telemetry_request,
                kwargs={
                    "res": res,
                    "error": error,
                },
            )

            thr.start()

        try:
            custom_headers = {f"x-peak-{key}": value for (key, value) in get_telemetry_data().items()}
            custom_headers = {
                **custom_headers,
                **(headers or {}),
            }

            res = make_request(
                self,
                url,
                method,
                content_type=content_type,
                headers=custom_headers,
                params=params or {},
                body=body or {},
                path=path,
                ignore_files=ignore_files,
                request_kwargs=request_kwargs,
            )

            trigger_usage_collection(res=res)
        except Exception as e:
            trigger_usage_collection(error=e)
            raise
        else:
            return res

    return wrapper
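Editor's note: a brief sketch of how the telemetry decorator above appears to be intended to be applied. In the package the wrapped target is Handler.make_request in peak/handler.py; the DemoHandler class below is a hypothetical stand-in, not part of the wheel, and invoking the decorated method would also post usage data in a background thread.

# Illustrative sketch only; DemoHandler is hypothetical (the real target is Handler.make_request).
import requests

from peak.telemetry import telemetry


class DemoHandler:
    @telemetry
    def make_request(self, url, method, content_type, **kwargs):
        # The decorator injects x-peak-* headers (SDK version, OS, session and request IDs)
        # and then fires a background thread that POSTs usage data to the stage's telemetry endpoint.
        return requests.request(method.value, url, headers=kwargs.get("headers"), params=kwargs.get("params"))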
peak/template.py
ADDED
@@ -0,0 +1,231 @@
#
# # Copyright © 2023 Peak AI Limited. or its affiliates. All Rights Reserved.
# #
# # Licensed under the Apache License, Version 2.0 (the "License"). You
# # may not use this file except in compliance with the License. A copy of
# # the License is located at:
# #
# # https://github.com/PeakBI/peak-sdk/blob/main/LICENSE
# #
# # or in the "license" file accompanying this file. This file is
# # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# # ANY KIND, either express or implied. See the License for the specific
# # language governing permissions and limitations under the License.
# #
# # This file is part of the peak-sdk.
# # see (https://github.com/PeakBI/peak-sdk)
# #
# # You should have received a copy of the APACHE LICENSE, VERSION 2.0
# # along with this program. If not, see <https://apache.org/licenses/LICENSE-2.0>
#
"""Template module which handles all things related to templates."""
from __future__ import annotations

import os
import re
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import jinja2
import yaml
from jinja2 import Environment
from jinja2.ext import Extension

from peak import exceptions


def _parse_jinja_template(template_path: Path, params: Dict[str, Any]) -> str:
    """Read, parse and render the Jinja template text."""
    jinja_loader = _CustomJinjaLoader()
    jinja_env = jinja2.Environment(  # TODO: show warning if variable not found in params # noqa: TD002, TD003, RUF100
        loader=jinja_loader,
        autoescape=False,  # noqa: S701
        extensions=[_IncludeWithIndentation],
    )
    jinja_template: jinja2.Template = jinja_env.get_template(str(template_path))
    return jinja_template.render(params, env=os.environ)


def load_template(file: Union[Path, str], params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Load a template file through `Jinja` into a dictionary.

    This function performs the following steps:
    * Passes the `YAML` file to be loaded and parsed through **`Jinja`**
    * **`Jinja`** substitutes the variables with their values as they are found in `params`
    * Loads any other files that are referenced using the Jinja `{% include %}` directive, if it is present.
    * Updates the `context` key path within `image` definitions with respect to its relative parent file path.
    * Loads the rendered YAML file into a `dictionary`.

    Args:
        file (Union[Path, str]): Path to the templated `YAML` file to be loaded.
        params (Dict[str, Any] | None, optional): Named parameters to be passed to Jinja. Defaults to `{}`.

    Returns:
        Dict[str, Any]: Dictionary containing the rendered YAML file
    """
    params = {} if params is None else params
    file = Path(file)
    template: str = _parse_jinja_template(file, params)
    return yaml.safe_load(template)  # type: ignore[no-any-return]


class _CustomJinjaLoader(jinja2.BaseLoader):
    """Custom Jinja loader class which handles the include directive.

    Inspired from the jinja2.FileSystemLoader class.
    """

    def __init__(
        self,
        search_path: Optional[List[str]] = None,
        encoding: str = "utf-8",
    ) -> None:
        """Initialize all variables.

        Args:
            search_path (List[str] | None, optional): Path(s) of the directory to search file in.
            encoding (str): Encoding to use when reading files. Defaults to "utf-8".
        """
        if search_path is None:
            search_path = ["."]

        self.search_path: List[str] = [os.fspath(p) for p in search_path]
        self.encoding: str = encoding
        self.root_file_path: Optional[str] = None
        self.build_context_regex: re.Pattern[str] = re.compile(r"^(\s*context\s*:\s+)(.+)$", flags=re.MULTILINE)
        self.seen_files: Set[Path] = set()

    def _update_image_build_context(self, source: str, file_parent_dir: str) -> str:
        """Updates the context key in image definition to the relative path of the file where it is being imported.

        Args:
            source (str): Content of the file
            file_parent_dir (str): Directory where the file is located

        Returns:
            str: Content of the file with updated context value if present.
        """
        if self.root_file_path is None:
            return source

        def substitute(match: re.Match[str]) -> str:
            """Substitute the context key with the relative path of the file where it is being imported."""
            context_path: str = match.group(2)
            context_path = context_path.strip().strip('"').strip("'")
            final_path: str = str(Path(file_parent_dir) / Path(context_path))
            final_relative_path: str = os.path.relpath(final_path, self.root_file_path)
            return f"{match.group(1)}{final_relative_path}"

        return self.build_context_regex.sub(substitute, source)

    def get_source(self, _: jinja2.Environment, template_path: str) -> Tuple[str, str, Callable[[], bool]]:
        """Searches and reads the template file.

        Args:
            _ (jinja2.Environment): Jinja environment variable.
            template_path (str): Path of the template file.

        # noqa: DAR401
        Raises:
            jinja2.TemplateNotFound: The template file is not found on given path.

        Returns:
            Tuple[str, str, Callable[[], bool]]: Tuple containing 3 variables
                1. Content of the template file
                2. Normalized path where the file was found
                3. Function which checks if the file was modified (not useful in this case)
        """
        template_path = template_path.strip()
        for search_path in self.search_path:
            # Use posixpath even on Windows to avoid "drive:" or UNC
            # segments breaking out of the search directory.
            file_path: str = str(Path(search_path) / Path(template_path))

            if Path(file_path).is_file():
                break
        else:
            error_msg: str = f"File does not exist at path: {template_path!r}"
            raise jinja2.TemplateNotFound(error_msg)

        file_path_obj = Path(file_path)
        absolute_file_path: Path = file_path_obj.resolve()

        if absolute_file_path in self.seen_files:
            error_msg = f"Failed to render template, circular include directive found at {absolute_file_path!r}"
            raise exceptions.InvalidTemplateException(error_msg)

        file_parent_dir = str(file_path_obj.parent)
        with Path(file_path).open(encoding=self.encoding) as f:
            contents: str = f.read()
            contents = self._update_image_build_context(contents, file_parent_dir)

        self.seen_files.add(absolute_file_path)
        self.search_path.append(file_parent_dir)
        if self.root_file_path is None:
            self.root_file_path = file_parent_dir

        return contents, os.path.normpath(file_path), lambda: True


class _IncludeWithIndentation(Extension):
    """Override Jinja include directive to preserve indentation.

    Inspired from: https://github.com/stereobutter/jinja2_workarounds
    """

    @staticmethod
    def _include_statement_regex(block_start: str, block_end: str) -> re.Pattern[str]:
        """Get the compiled regex for finding the include directives in template."""
        return re.compile(
            rf"""
            (^.*)
            (?=
                (
                    {re.escape(block_start)}
                    (?P<block_start_modifier> [\+|-]?)
                    (?P<statement>
                        \s*include
                        \s+.*?
                    )
                    (?P<block_end_modifier> [\+|-]?)
                    {re.escape(block_end)}
                )
            )
            .*$
            """,
            flags=re.MULTILINE | re.VERBOSE,
        )

    def preprocess(self, source: str, _: Optional[str], __: Optional[str] = None) -> str:
        """Enclose all include directives.

        For all the regex matches in the text, enclose the include block
        with indent filter blocks and return updated text.
        """
        env: Environment = self.environment

        block_start: str = env.block_start_string
        block_end: str = env.block_end_string
        pattern: re.Pattern[str] = self._include_statement_regex(block_start=block_start, block_end=block_end)
        re.compile("\n")

        def add_indentation_filter(match: re.Match[str]) -> str:
            """Add indent filter to include blocks."""
            content_before_include: str | Any = match.group(1)
            include_statement: str | Any = match.group("statement").replace("indent content", "")

            block_start_modifier: str | Any = match.group("block_start_modifier") or ""
            block_end_modifier: str | Any = match.group("block_end_modifier") or ""

            start_filter: str = (
                f"{block_start + block_start_modifier} filter indent({len(content_before_include)}) {block_end}"
            )
            include_block: str = f"{block_start} {include_statement} {block_end}"
            end_filter: str = f"{block_start} endfilter {block_end_modifier + block_end}"

            return f"{content_before_include}{start_filter}{include_block}{end_filter}"

        return pattern.sub(add_indentation_filter, source)


__all__: List[str] = ["load_template"]
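Editor's note: a rough usage sketch for load_template above. The file name, keys, and template contents below are made up for illustration; the template is rendered through Jinja with `params` plus `env` (the process environment) available as variables, then parsed with yaml.safe_load.

# Illustrative sketch only; "workflow.yaml" and its keys are hypothetical.
from peak.template import load_template

# workflow.yaml might contain, for example:
#   name: {{ workflow_name }}
#   image:
#     context: ./docker
#   steps:
#     {% include "steps.yaml" %}
spec = load_template("workflow.yaml", params={"workflow_name": "train-model"})
print(spec["name"])  # -> "train-model"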