peak-sdk 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. peak/__init__.py +36 -0
  2. peak/_version.py +21 -0
  3. peak/auth.py +22 -0
  4. peak/base_client.py +52 -0
  5. peak/cli/__init_.py +20 -0
  6. peak/cli/args.py +84 -0
  7. peak/cli/cli.py +56 -0
  8. peak/cli/helpers.py +187 -0
  9. peak/cli/press/__init__.py +21 -0
  10. peak/cli/press/apps/__init__.py +40 -0
  11. peak/cli/press/apps/deployments.py +238 -0
  12. peak/cli/press/apps/specs.py +387 -0
  13. peak/cli/press/blocks/__init__.py +40 -0
  14. peak/cli/press/blocks/deployments.py +240 -0
  15. peak/cli/press/blocks/specs.py +492 -0
  16. peak/cli/press/deployments.py +78 -0
  17. peak/cli/press/specs.py +131 -0
  18. peak/cli/resources/__init__.py +21 -0
  19. peak/cli/resources/artifacts.py +310 -0
  20. peak/cli/resources/images.py +886 -0
  21. peak/cli/resources/webapps.py +356 -0
  22. peak/cli/resources/workflows.py +703 -0
  23. peak/cli/ruff.toml +11 -0
  24. peak/cli/version.py +49 -0
  25. peak/compression.py +162 -0
  26. peak/config.py +24 -0
  27. peak/constants.py +105 -0
  28. peak/exceptions.py +217 -0
  29. peak/handler.py +358 -0
  30. peak/helpers.py +184 -0
  31. peak/logger.py +48 -0
  32. peak/press/__init__.py +28 -0
  33. peak/press/apps.py +669 -0
  34. peak/press/blocks.py +707 -0
  35. peak/press/deployments.py +145 -0
  36. peak/press/specs.py +260 -0
  37. peak/py.typed +0 -0
  38. peak/resources/__init__.py +28 -0
  39. peak/resources/artifacts.py +343 -0
  40. peak/resources/images.py +675 -0
  41. peak/resources/webapps.py +278 -0
  42. peak/resources/workflows.py +625 -0
  43. peak/session.py +259 -0
  44. peak/telemetry.py +201 -0
  45. peak/template.py +231 -0
  46. peak/validators.py +48 -0
  47. peak_sdk-1.0.0.dist-info/LICENSE +201 -0
  48. peak_sdk-1.0.0.dist-info/METADATA +199 -0
  49. peak_sdk-1.0.0.dist-info/RECORD +51 -0
  50. peak_sdk-1.0.0.dist-info/WHEEL +4 -0
  51. peak_sdk-1.0.0.dist-info/entry_points.txt +3 -0
peak/session.py ADDED
@@ -0,0 +1,259 @@
1
+ #
2
+ # # Copyright © 2023 Peak AI Limited. or its affiliates. All Rights Reserved.
3
+ # #
4
+ # # Licensed under the Apache License, Version 2.0 (the "License"). You
5
+ # # may not use this file except in compliance with the License. A copy of
6
+ # # the License is located at:
7
+ # #
8
+ # # https://github.com/PeakBI/peak-sdk/blob/main/LICENSE
9
+ # #
10
+ # # or in the "license" file accompanying this file. This file is
11
+ # # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
12
+ # # ANY KIND, either express or implied. See the License for the specific
13
+ # # language governing permissions and limitations under the License.
14
+ # #
15
+ # # This file is part of the peak-sdk.
16
+ # # see (https://github.com/PeakBI/peak-sdk)
17
+ # #
18
+ # # You should have received a copy of the APACHE LICENSE, VERSION 2.0
19
+ # # along with this program. If not, see <https://apache.org/licenses/LICENSE-2.0>
20
+ #
21
+ """Session module for Peak API."""
22
+ from __future__ import annotations
23
+
24
+ import os
25
+ from pathlib import Path
26
+ from typing import Any, Dict, Iterator, List, Optional
27
+
28
+ from peak import exceptions
29
+ from peak.constants import DOWNLOAD_CHUNK_SIZE, ContentType, HttpMethods, Stage
30
+ from peak.handler import Handler
31
+ from peak.helpers import get_base_domain
32
+ from peak.logger import logger
33
+
34
# Lazily-created, process-wide session shared by callers that do not supply their own.
DEFAULT_SESSION: Optional[Session] = None


def _get_default_session() -> Session:
    """Get the global default session object. Creates one if not already created and re-uses it.

    The first call builds a `Session()` (credentials and stage read from env variables) and stores
    it in the module-level `DEFAULT_SESSION`; later calls return that same object.

    Returns:
        Session: The default session object
    """
    global DEFAULT_SESSION  # noqa: PLW0603
    if DEFAULT_SESSION is None:
        logger.debug("Creating DEFAULT_SESSION object")
        DEFAULT_SESSION = Session()
    else:
        # Only report reuse when an existing session is actually being reused.
        logger.debug("DEFAULT_SESSION already present, reusing the object")
    return DEFAULT_SESSION
50
+
51
+
52
class Session:
    """A session stores credentials which are used to authenticate the requests.

    By default, a DEFAULT_SESSION is created which reads the credentials from the env variables.
    Custom Session objects can be created and used if you want to work with multiple tenants.
    """

    auth_token: str
    stage: Stage
    base_domain: str
    handler: Handler

    def __init__(
        self,
        auth_token: Optional[str] = None,
        stage: Optional[str] = None,
    ) -> None:
        """Initialize a session for the Peak API.

        Args:
            auth_token (str | None): Authentication token. Both API Key and Bearer tokens are supported.
                Picks up from `API_KEY` environment variable if not provided; if that is unset or
                empty, `MissingEnvironmentVariableException` is raised.
            stage (str | None): Name of the stage where tenant is created. Picks up from `STAGE`
                environment variable if not provided, and defaults to `prod` when that is unset too.
        """
        self.base_domain: str = ""
        self._set_auth_token(auth_token)
        self._set_stage(stage)
        self._set_base_domain()
        self.handler = Handler()

    def create_request(
        self,
        endpoint: str,
        method: HttpMethods,
        content_type: ContentType,
        *,
        params: Optional[Dict[str, Any]] = None,
        body: Optional[Dict[str, Any]] = None,
        path: Optional[str] = None,
        ignore_files: Optional[list[str]] = None,
        subdomain: Optional[str] = "service",
    ) -> Any:
        """Prepares a request to be sent over the network.

        Adds auth_token to the headers and creates URL using STAGE.
        To be used with endpoints which returns JSON parsable response.

        Args:
            endpoint (str): The endpoint to send the request to.
            method (HttpMethods): The HTTP method to use.
            content_type (ContentType): The content type of the request.
            params (Dict[str, Any], optional): params to send to the request, defaults to None
            body (Dict[str, Any], optional): body to send to the request, defaults to None
            path (Optional[str] optional): path to the file or folder that will be compressed and used as artifact, required for multipart requests.
            ignore_files(Optional[list[str]]): Ignore files to be used when creating artifact, used only for multipart requests.
            subdomain (Optional[str]): Subdomain for the endpoint. Defaults to `service`.

        Returns:
            Any: response dict object.
        """
        headers: Dict[str, str] = {"Authorization": self.auth_token}
        base_domain: str = get_base_domain(stage=self.stage.value, subdomain=subdomain)
        url: str = f"{base_domain}/{endpoint}"
        return self.handler.make_request(
            url,
            method,
            content_type=content_type,
            headers=headers,
            params=params or {},
            body=body or {},
            path=path,
            ignore_files=ignore_files,
            # Carries the stage so the telemetry wrapper can report to the matching stage.
            session_meta={
                "stage": self.stage,
            },
        ).json()

    def create_generator_request(
        self,
        endpoint: str,
        method: HttpMethods,
        content_type: ContentType,
        response_key: str,
        *,
        params: Optional[Dict[str, Any]] = None,
        body: Optional[Dict[str, Any]] = None,
        path: Optional[str] = None,
        subdomain: Optional[str] = "service",
        ignore_files: Optional[list[str]] = None,
    ) -> Iterator[Dict[str, Any]]:
        """Prepares a request to be sent over the network.

        Adds auth_token to the headers and creates URL using STAGE.
        Returns an iterator which automatically handles pagination and returns a new page at each iteration.
        To be used with list endpoints only, which returns `pageNumber`, `pageCount` keys in response.

        # noqa: DAR201

        Args:
            endpoint (str): The endpoint to send the request to.
            method (HttpMethods): The HTTP method to use.
            content_type (ContentType): The content type of the request.
            response_key (str): key in the response dict which contains actual list data.
            params (Optional[Dict[str, Any]]): params to send to the request.
            body (Optional[Dict[str, Any]]): body to send to the request.
            path (Optional[str]): path to the file or folder that will be compressed and used as artifact.
            subdomain (Optional[str]): Subdomain for the endpoint. Defaults to `service`.
            ignore_files (Optional[list[str]]): Ignore files to be used when creating artifact, forwarded to every page request.

        Yields:
            Iterator[Dict[str, Any]]: paginated response json, element wise.

        Raises:
            StopIteration: There are no more pages to list
        """
        page_number: int = 1
        page_count: int = 1
        params = params or {}
        while page_number <= page_count:
            params = {**params, "pageNumber": page_number}
            response = self.create_request(
                endpoint,
                method,
                content_type,
                params=params,
                body=body,
                path=path,
                ignore_files=ignore_files,
                subdomain=subdomain,
            )
            # The total page count is refreshed from every response.
            page_count = response["pageCount"]
            yield from response[response_key]
            page_number += 1
        # Becomes the `StopIteration.value` once the generator is exhausted.
        return f"No more {response_key} to list"

    def create_download_request(
        self,
        endpoint: str,
        method: HttpMethods,
        content_type: ContentType,
        download_path: str,
        params: Optional[Dict[str, Any]] = None,
        body: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Prepares a request to be sent over the network.

        Adds auth_token to the headers and creates URL using STAGE.
        To be used for file download requests. The response is streamed to `download_path`
        in chunks of `DOWNLOAD_CHUNK_SIZE` and the response is closed afterwards.

        Args:
            endpoint (str): The endpoint to send the request to.
            method (HttpMethods): The HTTP method to use.
            content_type (ContentType): The content type of the request.
            download_path (str): Path where the downloaded file will be stored.
            params (Dict[str, Any], optional): params to send to the request, defaults to None
            body (Dict[str, Any], optional): body to send to the request, only used in multipart requests, defaults to None

        Raises:
            InvalidPathException: The download_path is invalid.
        """
        headers: Dict[str, str] = {"Authorization": self.auth_token}
        url: str = f"{self.base_domain}/{endpoint}"
        response: Any = self.handler.make_request(
            url,
            method,
            content_type=content_type,
            headers=headers,
            params=params or {},
            body=body or {},
            request_kwargs={
                "stream": True,
                "allow_redirects": True,
            },
            # Same as `create_request`: without this, telemetry would always target PROD.
            session_meta={
                "stage": self.stage,
            },
        )
        try:
            with Path(download_path).open("wb") as fd:
                for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                    fd.write(chunk)
        except IsADirectoryError:
            raise exceptions.InvalidPathException(
                download_path,
                "Path should include the name with which the downloaded file should be stored.",
            ) from None
        finally:
            # A streamed response keeps its connection until the body is fully read or the
            # response is closed; close it explicitly so error paths do not leak it.
            response.close()

    def _set_base_domain(self) -> None:
        """Compute the stage-specific base domain used by download requests."""
        self.base_domain = get_base_domain(stage=self.stage.value)

    def _set_auth_token(self, auth_token: Optional[str]) -> None:
        """Use the given token, falling back to the `API_KEY` environment variable."""
        if auth_token is not None:
            self.auth_token = auth_token
            return

        logger.info("auth_token not given, searching for API_KEY in env variables")
        api_key = os.environ.get("API_KEY")
        if not api_key:
            raise exceptions.MissingEnvironmentVariableException(env_var="API_KEY")
        self.auth_token = api_key

    def _set_stage(self, stage: Optional[str]) -> None:
        """Use the given stage, falling back to the `STAGE` env variable, then to `Stage.PROD`."""
        if stage is not None:
            self.stage = Stage(stage)
            return

        logger.info("stage not given, searching for STAGE in env variables")
        env_stage = os.environ.get("STAGE")
        if not env_stage:
            logger.info("STAGE environment variable is not set, defaulting to PROD")
            self.stage = Stage.PROD
            return
        self.stage = Stage(env_stage)
257
+
258
+
259
# Public API of this module; `_get_default_session` is exported despite its leading underscore.
__all__: List[str] = ["Session", "_get_default_session"]
peak/telemetry.py ADDED
@@ -0,0 +1,201 @@
1
+ #
2
+ # # Copyright © 2023 Peak AI Limited. or its affiliates. All Rights Reserved.
3
+ # #
4
+ # # Licensed under the Apache License, Version 2.0 (the "License"). You
5
+ # # may not use this file except in compliance with the License. A copy of
6
+ # # the License is located at:
7
+ # #
8
+ # # https://github.com/PeakBI/peak-sdk/blob/main/LICENSE
9
+ # #
10
+ # # or in the "license" file accompanying this file. This file is
11
+ # # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
12
+ # # ANY KIND, either express or implied. See the License for the specific
13
+ # # language governing permissions and limitations under the License.
14
+ # #
15
+ # # This file is part of the peak-sdk.
16
+ # # see (https://github.com/PeakBI/peak-sdk)
17
+ # #
18
+ # # You should have received a copy of the APACHE LICENSE, VERSION 2.0
19
+ # # along with this program. If not, see <https://apache.org/licenses/LICENSE-2.0>
20
+ #
21
+ """Decorator for sending telemetry data for each request."""
22
+ from __future__ import annotations
23
+
24
+ import platform
25
+ import threading
26
+ import uuid
27
+ from contextlib import suppress
28
+ from functools import wraps
29
+ from typing import Any, Callable, Dict, Optional
30
+
31
+ import requests
32
+
33
+ import peak.config
34
+ from peak.constants import ContentType, HttpMethods, Stage
35
+ from peak.exceptions import BaseHttpException
36
+ from peak.helpers import get_base_domain
37
+
38
+ from ._version import __version__
39
+
40
# Type of the `make_request` callables wrapped by `telemetry`: any arguments, returns a `requests.Response`.
F = Callable[..., requests.Response]
# Generated once at import time, so every request from this process reports the same `sessionId`.
session_id = str(uuid.uuid4())
42
+
43
+
44
def get_status_code(error: Optional[Exception]) -> int | None:
    """Map the outcome of a request to the status code reported in telemetry.

    Args:
        error (Optional[Exception]): The exception raised by the request, or None when it succeeded

    Returns:
        int | None: 200 for a successful request, the exception's `STATUS_CODE` for HTTP errors,
            or None when the error carries no status code
    """
    if error:
        # Only the SDK's HTTP exceptions know which status code they represent.
        return error.STATUS_CODE if isinstance(error, BaseHttpException) else None
    return 200
60
+
61
+
62
def telemetry(make_request: F) -> F:
    """A decorator that wraps over the make_request function to send telemetry requests as required.

    After each wrapped call, a background thread POSTs a usage record (URL, method, status code,
    error text, SDK/platform details) to the telemetry endpoint of the request's stage. Telemetry is
    best-effort: any failure while building or sending that record is suppressed.

    Args:
        make_request (F): The make_request function to wrap in the decorator

    Returns:
        F: the wrapped function that sends telemetry data for each request
    """

    def get_telemetry_url(session_meta: Optional[Dict[str, Any]] = None) -> str:
        """Returns the telemetry url for the given stage.

        Args:
            session_meta (Optional[Dict[str, Any]]): Session metadata object that contains information like stage

        Returns:
            str: The telemetry URL
        """
        stage = Stage.PROD
        if session_meta and "stage" in session_meta:
            stage = session_meta["stage"]
        base_domain = get_base_domain(stage.value, "service")
        return f"{base_domain}/resource-usage/api/v1/telemetry"

    def get_telemetry_data() -> Dict[str, Any]:
        """Collect SDK/platform details plus a freshly generated request id."""
        return {
            "sdkVersion": __version__,
            "os": platform.platform(),
            "hostname": platform.uname().node,
            "pythonVersion": platform.python_version(),
            "sessionId": session_id,
            "requestId": str(uuid.uuid4()),
        }

    @wraps(make_request)
    def wrapper(
        self: Any,
        url: str,
        method: HttpMethods,
        content_type: ContentType,
        headers: Optional[Dict[str, str]] = None,
        params: Optional[Dict[str, Any]] = None,
        body: Optional[Dict[str, Any]] = None,
        path: Optional[str] = None,
        request_kwargs: Optional[Dict[str, int | bool | str | float]] = None,
        ignore_files: Optional[list[str]] = None,
        session_meta: Optional[Dict[str, Any]] = None,
    ) -> requests.Response:
        """A decorator that wraps over the make_request function to send telemetry requests as required.

        Args:
            self (Any): the object instance of Handler class on which the make_request call is being made
            url (str): url to send the request to
            method (HttpMethods): The HTTP method to use, e.g. get, post, put, delete
            content_type (ContentType): content type of the request
            headers (Dict[str, str]): headers to send with the request
            params (Dict[str, Any]): params to send to the request
            body (Dict[str, Any]): body to send to the request
            path (str): path to the file or folder that will be compressed and used as artifact, defaults to None
            request_kwargs(Dict[str, int | bool | str | float] | None): extra arguments to be passed when making the request.
            ignore_files(Optional[list[str]]): Ignore files to be used when creating artifact
            session_meta(Dict[str, Any]): Metadata about the session object, like - stage

        Returns:
            requests.Response: the response returned by the wrapped make_request

        Raises:
            BaseHttpException: The http request failed.
            Exception: Some other error occurred.
        """
        # Build the metadata once so the `x-peak-requestId` header sent with the request
        # matches the `requestId` reported in the telemetry record for that request.
        telemetry_data = get_telemetry_data()
        is_streamed = bool(request_kwargs and request_kwargs.get("stream"))

        def get_response_body(res: Optional[requests.Response]) -> Any:
            """Return the JSON body to report, or None when it must not or cannot be read."""
            # GET bodies are never reported. A streamed body belongs to the caller and must not be
            # consumed from the telemetry thread. Non-JSON bodies (e.g. empty responses) are skipped.
            if method == HttpMethods.GET or not res or is_streamed:
                return None
            try:
                return res.json()
            except ValueError:
                return None

        def make_telemetry_request(
            res: Optional[requests.Response] = None,
            error: Optional[Exception] = None,
        ) -> None:
            """Build and send the telemetry record; runs on a background thread."""
            # Best-effort: nothing raised while building or sending the record may escape the thread.
            with suppress(Exception):
                telemetry_body = {
                    "response": get_response_body(res),
                    "error": str(error) if error else None,
                    "url": url,
                    "requestMethod": method.value,
                    "statusCode": get_status_code(error),
                    "source": peak.config.SOURCE.value,
                    **telemetry_data,
                }
                make_request(
                    self,
                    get_telemetry_url(session_meta),
                    HttpMethods.POST,
                    ContentType.APPLICATION_JSON,
                    headers=headers,
                    body=telemetry_body,
                )

        def trigger_usage_collection(
            res: Optional[Any] = None,
            error: Optional[Exception] = None,
        ) -> None:
            """Send telemetry on a separate thread so the caller is not blocked by it."""
            threading.Thread(
                target=make_telemetry_request,
                kwargs={
                    "res": res,
                    "error": error,
                },
            ).start()

        custom_headers = {
            **{f"x-peak-{key}": value for key, value in telemetry_data.items()},
            # Caller-supplied headers take precedence over the telemetry headers.
            **(headers or {}),
        }

        try:
            res = make_request(
                self,
                url,
                method,
                content_type=content_type,
                headers=custom_headers,
                params=params or {},
                body=body or {},
                path=path,
                ignore_files=ignore_files,
                request_kwargs=request_kwargs,
            )
        except Exception as e:
            trigger_usage_collection(error=e)
            raise

        # Outside the `try`, so a telemetry hiccup is never reported as a request error.
        trigger_usage_collection(res=res)
        return res

    return wrapper
peak/template.py ADDED
@@ -0,0 +1,231 @@
1
+ #
2
+ # # Copyright © 2023 Peak AI Limited. or its affiliates. All Rights Reserved.
3
+ # #
4
+ # # Licensed under the Apache License, Version 2.0 (the "License"). You
5
+ # # may not use this file except in compliance with the License. A copy of
6
+ # # the License is located at:
7
+ # #
8
+ # # https://github.com/PeakBI/peak-sdk/blob/main/LICENSE
9
+ # #
10
+ # # or in the "license" file accompanying this file. This file is
11
+ # # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
12
+ # # ANY KIND, either express or implied. See the License for the specific
13
+ # # language governing permissions and limitations under the License.
14
+ # #
15
+ # # This file is part of the peak-sdk.
16
+ # # see (https://github.com/PeakBI/peak-sdk)
17
+ # #
18
+ # # You should have received a copy of the APACHE LICENSE, VERSION 2.0
19
+ # # along with this program. If not, see <https://apache.org/licenses/LICENSE-2.0>
20
+ #
21
+ """Template module which handles all things related to templates."""
22
+ from __future__ import annotations
23
+
24
+ import os
25
+ import re
26
+ from pathlib import Path
27
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
28
+
29
+ import jinja2
30
+ import yaml
31
+ from jinja2 import Environment
32
+ from jinja2.ext import Extension
33
+
34
+ from peak import exceptions
35
+
36
+
37
def _parse_jinja_template(template_path: Path, params: Dict[str, Any]) -> str:
    """Render the Jinja template found at `template_path` and return the resulting text.

    `params` are exposed as template variables and the process environment is available as `env`.
    Includes are resolved by `_CustomJinjaLoader` and keep their indentation via `_IncludeWithIndentation`.
    """
    # TODO: show warning if variable not found in params # noqa: TD002, TD003, RUF100
    environment = jinja2.Environment(
        loader=_CustomJinjaLoader(),
        # The rendered output is YAML (see `load_template`), not HTML, so nothing is escaped.
        autoescape=False,  # noqa: S701
        extensions=[_IncludeWithIndentation],
    )
    template: jinja2.Template = environment.get_template(str(template_path))
    return template.render(params, env=os.environ)
47
+
48
+
49
def load_template(file: Union[Path, str], params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Load a template file through `Jinja` into a dictionary.

    This function performs the following steps:
    * Passes the `YAML` file to be loaded and parsed through **`Jinja`**
    * **`Jinja`** substitutes the variables with their values as they are found in `params`
    * Loads any other files that are referenced using the Jinja `{% include %}` directive, if it is present.
    * Updates the `context` key path within `image` definitions with respect to its relative parent file path.
    * Loads the rendered YAML file into a `dictionary`.

    Args:
        file (Union[Path, str]): Path to the templated `YAML` file to be loaded.
        params (Dict[str, Any] | None, optional): Named parameters to be passed to Jinja. Defaults to `{}`.

    Returns:
        Dict[str, Any]: Dictionary containing the rendered YAML file
    """
    # `params` now really is optional, as the docstring always stated; None means "no parameters".
    params = {} if params is None else params
    file = Path(file)
    template: str = _parse_jinja_template(file, params)
    # NOTE(review): `yaml.safe_load` returns None for an empty document despite the Dict annotation.
    return yaml.safe_load(template)  # type: ignore[no-any-return]
70
+
71
+
72
class _CustomJinjaLoader(jinja2.BaseLoader):
    """Custom Jinja loader class which handles the include directive.

    Inspired from the jinja2.FileSystemLoader class. In addition to locating files it:
    * appends the directory of every loaded file to the search path, so later includes also
      resolve against the directories of previously loaded files;
    * rewrites `context:` values of included files so they stay relative to the root template;
    * refuses to load the same resolved file twice, reporting it as a circular include.
    """

    def __init__(
        self,
        search_path: Optional[List[str]] = None,
        encoding: str = "utf-8",
    ) -> None:
        """Initialize all variables.

        Args:
            search_path (List[str] | None, optional): Path(s) of the directory to search file in.
                Defaults to the current working directory (`"."`).
            encoding (str): Encoding to use when reading files. Defaults to "utf-8".
        """
        if search_path is None:
            search_path = ["."]

        self.search_path: List[str] = [os.fspath(p) for p in search_path]
        self.encoding: str = encoding
        # Parent directory of the first template loaded; build contexts are made relative to it.
        self.root_file_path: Optional[str] = None
        # Matches `context: <value>` lines (anywhere in the text, thanks to MULTILINE).
        self.build_context_regex: re.Pattern[str] = re.compile(r"^(\s*context\s*:\s+)(.+)$", flags=re.MULTILINE)
        # Resolved paths of every file loaded so far.
        self.seen_files: Set[Path] = set()

    def _update_image_build_context(self, source: str, file_parent_dir: str) -> str:
        """Updates the context key in image definition to the relative path of the file where it is being imported.

        Args:
            source (str): Content of the file
            file_parent_dir (str): Directory where the file is located

        Returns:
            str: Content of the file with updated context value if present.
        """
        # The root template itself is returned untouched: its contexts are already relative to it.
        if self.root_file_path is None:
            return source

        def substitute(match: re.Match[str]) -> str:
            """Substitute the context key with the relative path of the file where it is being imported."""
            context_path: str = match.group(2)
            context_path = context_path.strip().strip('"').strip("'")
            final_path: str = str(Path(file_parent_dir) / Path(context_path))
            final_relative_path: str = os.path.relpath(final_path, self.root_file_path)
            return f"{match.group(1)}{final_relative_path}"

        return self.build_context_regex.sub(substitute, source)

    def get_source(self, _: jinja2.Environment, template_path: str) -> Tuple[str, str, Callable[[], bool]]:
        """Searches and reads the template file.

        Args:
            _ (jinja2.Environment): Jinja environment variable.
            template_path (str): Path of the template file.

        # noqa: DAR401
        Raises:
            jinja2.TemplateNotFound: The template file is not found on given path.

        Returns:
            Tuple[str, str, Callable[[], bool]]: Tuple containing 3 variables
                1. Content of the template file
                2. Normalized path where the file was found
                3. Function which checks if the file was modified (not useful in this case)
        """
        template_path = template_path.strip()
        # NOTE: `Path(search_path) / template_path` evaluates to `template_path` itself when it is
        # absolute, so absolute template paths are NOT confined to the search directories.
        for search_path in self.search_path:
            file_path: str = str(Path(search_path) / Path(template_path))

            if Path(file_path).is_file():
                break
        else:
            error_msg: str = f"File does not exist at path: {template_path!r}"
            raise jinja2.TemplateNotFound(error_msg)

        file_path_obj = Path(file_path)
        absolute_file_path: Path = file_path_obj.resolve()

        # NOTE(review): `seen_files` is never cleared, so this also fires when the same file is
        # loaded a second time without a cycle (e.g. included under two different names).
        if absolute_file_path in self.seen_files:
            error_msg = f"Failed to render template, circular include directive found at {absolute_file_path!r}"
            raise exceptions.InvalidTemplateException(error_msg)

        file_parent_dir = str(file_path_obj.parent)
        with Path(file_path).open(encoding=self.encoding) as f:
            contents: str = f.read()
        contents = self._update_image_build_context(contents, file_parent_dir)

        self.seen_files.add(absolute_file_path)
        # Later includes may be relative to this file's directory. Skipping duplicates keeps the
        # search list from growing on every load without changing which file is found first.
        if file_parent_dir not in self.search_path:
            self.search_path.append(file_parent_dir)
        if self.root_file_path is None:
            self.root_file_path = file_parent_dir

        # Jinja treats an `uptodate` returning True as "still fresh", so cached templates are reused.
        return contents, os.path.normpath(file_path), lambda: True
168
+
169
+
170
class _IncludeWithIndentation(Extension):
    """Override Jinja include directive to preserve indentation.

    Every include block is wrapped in an `indent` filter whose width equals the number of
    characters preceding the block on its line, so multi-line included content lines up.

    Inspired from: https://github.com/stereobutter/jinja2_workarounds
    """

    @staticmethod
    def _include_statement_regex(block_start: str, block_end: str) -> re.Pattern[str]:
        """Get the compiled regex for finding the include directives in template.

        Groups:
            1: everything on the line before the (last) include block.
            block_start_modifier / block_end_modifier: optional whitespace-control modifier (`+` or `-`).
            statement: the `include ...` statement between the block delimiters.
            rest: anything on the line after the include block.
        """
        return re.compile(
            rf"""
            (^.*)
            {re.escape(block_start)}
            (?P<block_start_modifier> [+-]?)
            (?P<statement>
                \s*include
                \s+.*?
            )
            (?P<block_end_modifier> [+-]?)
            {re.escape(block_end)}
            (?P<rest>.*)$
            """,
            flags=re.MULTILINE | re.VERBOSE,
        )

    def preprocess(self, source: str, _: Optional[str], __: Optional[str] = None) -> str:
        """Enclose all include directives.

        For all the regex matches in the text, enclose the include block
        with indent filter blocks and return updated text. Text that follows the
        include block on the same line is kept after the closing filter block.
        """
        env: Environment = self.environment

        block_start: str = env.block_start_string
        block_end: str = env.block_end_string
        pattern: re.Pattern[str] = self._include_statement_regex(block_start=block_start, block_end=block_end)

        def add_indentation_filter(match: re.Match[str]) -> str:
            """Add indent filter to include blocks."""
            content_before_include: str | Any = match.group(1)
            # `indent content` is the upstream opt-in marker; it is not valid Jinja, so drop it.
            include_statement: str | Any = match.group("statement").replace("indent content", "")

            block_start_modifier: str | Any = match.group("block_start_modifier") or ""
            block_end_modifier: str | Any = match.group("block_end_modifier") or ""
            trailing_content: str | Any = match.group("rest")

            # The whitespace-control modifiers move to the outer filter block so the
            # wrapped include behaves like the original block towards its surroundings.
            start_filter: str = (
                f"{block_start}{block_start_modifier} filter indent({len(content_before_include)}) {block_end}"
            )
            include_block: str = f"{block_start} {include_statement} {block_end}"
            end_filter: str = f"{block_start} endfilter {block_end_modifier}{block_end}"

            return f"{content_before_include}{start_filter}{include_block}{end_filter}{trailing_content}"

        return pattern.sub(add_indentation_filter, source)
229
+
230
+
231
# Only `load_template` is public; the loader and extension classes are internal helpers.
__all__: List[str] = ["load_template"]