snowflake-cli-labs 2.7.0rc3__py3-none-any.whl → 2.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. snowflake/cli/__about__.py +1 -1
  2. snowflake/cli/api/feature_flags.py +1 -2
  3. snowflake/cli/api/project/definition.py +3 -36
  4. snowflake/cli/api/project/errors.py +16 -1
  5. snowflake/cli/api/project/schemas/entities/application_entity.py +5 -11
  6. snowflake/cli/api/project/schemas/entities/application_package_entity.py +5 -2
  7. snowflake/cli/api/project/schemas/entities/common.py +15 -22
  8. snowflake/cli/api/project/schemas/native_app/application.py +10 -2
  9. snowflake/cli/api/project/schemas/native_app/native_app.py +13 -2
  10. snowflake/cli/api/project/schemas/native_app/package.py +24 -1
  11. snowflake/cli/api/project/schemas/project_definition.py +23 -40
  12. snowflake/cli/api/project/schemas/snowpark/callable.py +1 -3
  13. snowflake/cli/api/project/schemas/updatable_model.py +148 -5
  14. snowflake/cli/api/project/util.py +55 -7
  15. snowflake/cli/api/rendering/jinja.py +1 -0
  16. snowflake/cli/api/rendering/project_templates.py +8 -7
  17. snowflake/cli/api/rendering/sql_templates.py +8 -4
  18. snowflake/cli/api/utils/definition_rendering.py +50 -11
  19. snowflake/cli/api/utils/models.py +10 -7
  20. snowflake/cli/api/utils/templating_functions.py +144 -0
  21. snowflake/cli/app/build_and_push.sh +8 -0
  22. snowflake/cli/app/snow_connector.py +14 -10
  23. snowflake/cli/plugins/init/commands.py +13 -7
  24. snowflake/cli/plugins/nativeapp/manager.py +93 -10
  25. snowflake/cli/plugins/nativeapp/project_model.py +13 -3
  26. snowflake/cli/plugins/nativeapp/run_processor.py +22 -51
  27. snowflake/cli/plugins/nativeapp/v2_conversions/v2_to_v1_decorator.py +7 -18
  28. snowflake/cli/plugins/nativeapp/version/version_processor.py +4 -0
  29. snowflake/cli/plugins/snowpark/commands.py +6 -3
  30. {snowflake_cli_labs-2.7.0rc3.dist-info → snowflake_cli_labs-2.8.0.dist-info}/METADATA +1 -1
  31. {snowflake_cli_labs-2.7.0rc3.dist-info → snowflake_cli_labs-2.8.0.dist-info}/RECORD +34 -32
  32. {snowflake_cli_labs-2.7.0rc3.dist-info → snowflake_cli_labs-2.8.0.dist-info}/WHEEL +0 -0
  33. {snowflake_cli_labs-2.7.0rc3.dist-info → snowflake_cli_labs-2.8.0.dist-info}/entry_points.txt +0 -0
  34. {snowflake_cli_labs-2.7.0rc3.dist-info → snowflake_cli_labs-2.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -17,7 +17,7 @@ from __future__ import annotations
17
17
  import codecs
18
18
  import os
19
19
  import re
20
- from typing import Optional
20
+ from typing import List, Optional
21
21
  from urllib.parse import quote
22
22
 
23
23
  IDENTIFIER = r'((?:"[^"]*(?:""[^"]*)*")|(?:[A-Za-z_][\w$]{0,254}))'
@@ -42,12 +42,20 @@ def encode_uri_component(s: str) -> str:
42
42
  return quote(s, safe="!~*'()")
43
43
 
44
44
 
45
- def clean_identifier(input_: str):
45
+ def sanitize_identifier(input_: str):
46
46
  """
47
- Removes characters that cannot be used in an unquoted identifier,
48
- converting to lowercase as well.
47
+ Removes characters that cannot be used in an unquoted identifier.
48
+ If the identifier does not start with a letter or underscore, prefix it with an underscore.
49
+ Limits the identifier to 255 characters.
49
50
  """
50
- return re.sub(r"[^a-z0-9_$]", "", f"{input_}".lower())
51
+ value = re.sub(r"[^a-zA-Z0-9_$]", "", f"{input_}")
52
+
53
+ # if it does not start with a letter or underscore, prefix it with an underscore
54
+ if not value or not re.match(r"[a-zA-Z_]", value[0]):
55
+ value = f"_{value}"
56
+
57
+ # limit it to 255 characters
58
+ return value[:255]
51
59
 
52
60
 
53
61
  def is_valid_unquoted_identifier(identifier: str) -> bool:
@@ -88,6 +96,18 @@ def is_valid_object_name(name: str, max_depth=2, allow_quoted=True) -> bool:
88
96
  return re.fullmatch(pattern, name) is not None
89
97
 
90
98
 
99
+ def to_quoted_identifier(input_value: str) -> str:
100
+ """
101
+ Turn the input into a valid quoted identifier.
102
+ If it is already a valid quoted identifier,
103
+ return it as is.
104
+ """
105
+ if is_valid_quoted_identifier(input_value):
106
+ return input_value
107
+
108
+ return '"' + input_value.replace('"', '""') + '"'
109
+
110
+
91
111
  def to_identifier(name: str) -> str:
92
112
  """
93
113
  Converts a name to a valid Snowflake identifier. If the name is already a valid
@@ -96,8 +116,15 @@ def to_identifier(name: str) -> str:
96
116
  if is_valid_identifier(name):
97
117
  return name
98
118
 
99
- # double quote the identifier
100
- return '"' + name.replace('"', '""') + '"'
119
+ return to_quoted_identifier(name)
120
+
121
+
122
+ def identifier_to_str(identifier: str) -> str:
123
+ if is_valid_quoted_identifier(identifier):
124
+ unquoted_id = identifier[1:-1]
125
+ return unquoted_id.replace('""', '"')
126
+
127
+ return identifier
101
128
 
102
129
 
103
130
  def append_to_identifier(identifier: str, suffix: str) -> str:
@@ -183,6 +210,27 @@ def get_env_username() -> Optional[str]:
183
210
  return first_set_env("USER", "USERNAME", "LOGNAME")
184
211
 
185
212
 
213
+ def concat_identifiers(identifiers: list[str]) -> str:
214
+ """
215
+ Concatenate multiple identifiers.
216
+ If any of them is quoted identifier or contains unsafe characters, quote the result.
217
+ Otherwise, the resulting identifier will be unquoted.
218
+ """
219
+ quotes_found = False
220
+ stringified_identifiers: List[str] = []
221
+
222
+ for identifier in identifiers:
223
+ if is_valid_quoted_identifier(identifier):
224
+ quotes_found = True
225
+ stringified_identifiers.append(identifier_to_str(identifier))
226
+
227
+ concatenated_ids_str = "".join(stringified_identifiers)
228
+ if quotes_found:
229
+ return to_quoted_identifier(concatenated_ids_str)
230
+
231
+ return to_identifier(concatenated_ids_str)
232
+
233
+
186
234
  SUPPORTED_VERSIONS = [1]
187
235
 
188
236
 
@@ -24,6 +24,7 @@ from jinja2 import Environment, StrictUndefined, loaders
24
24
  from snowflake.cli.api.secure_path import UNLIMITED, SecurePath
25
25
 
26
26
  CONTEXT_KEY = "ctx"
27
+ FUNCTION_KEY = "fn"
27
28
 
28
29
 
29
30
  def read_file_content(file_name: str):
@@ -28,8 +28,10 @@ from snowflake.cli.api.exceptions import InvalidTemplate
28
28
  from snowflake.cli.api.rendering.jinja import IgnoreAttrEnvironment, env_bootstrap
29
29
  from snowflake.cli.api.secure_path import SecurePath
30
30
 
31
- _PROJECT_TEMPLATE_START = "<!"
32
- _PROJECT_TEMPLATE_END = "!>"
31
+ _VARIABLE_TEMPLATE_START = "<!"
32
+ _VARIABLE_TEMPLATE_END = "!>"
33
+ _BLOCK_TEMPLATE_START = "<!!"
34
+ _BLOCK_TEMPLATE_END = "!!>"
33
35
 
34
36
 
35
37
  def to_snowflake_identifier(value: Optional[str]) -> Optional[str]:
@@ -60,15 +62,14 @@ PROJECT_TEMPLATE_FILTERS = [to_snowflake_identifier]
60
62
 
61
63
 
62
64
  def get_template_cli_jinja_env(template_root: SecurePath) -> Environment:
63
- _random_block = "___very___unique___block___to___disable___logic___blocks___"
64
65
  env = env_bootstrap(
65
66
  IgnoreAttrEnvironment(
66
67
  loader=loaders.FileSystemLoader(searchpath=template_root.path),
67
68
  keep_trailing_newline=True,
68
- variable_start_string=_PROJECT_TEMPLATE_START,
69
- variable_end_string=_PROJECT_TEMPLATE_END,
70
- block_start_string=_random_block,
71
- block_end_string=_random_block,
69
+ variable_start_string=_VARIABLE_TEMPLATE_START,
70
+ variable_end_string=_VARIABLE_TEMPLATE_END,
71
+ block_start_string=_BLOCK_TEMPLATE_START,
72
+ block_end_string=_BLOCK_TEMPLATE_END,
72
73
  undefined=StrictUndefined,
73
74
  )
74
75
  )
@@ -21,12 +21,14 @@ from jinja2 import StrictUndefined, loaders
21
21
  from snowflake.cli.api.cli_global_context import cli_context
22
22
  from snowflake.cli.api.rendering.jinja import (
23
23
  CONTEXT_KEY,
24
+ FUNCTION_KEY,
24
25
  IgnoreAttrEnvironment,
25
26
  env_bootstrap,
26
27
  )
27
28
 
28
29
  _SQL_TEMPLATE_START = "&{"
29
30
  _SQL_TEMPLATE_END = "}"
31
+ RESERVED_KEYS = [CONTEXT_KEY, FUNCTION_KEY]
30
32
 
31
33
 
32
34
  def get_sql_cli_jinja_env(*, loader: Optional[loaders.BaseLoader] = None):
@@ -46,10 +48,12 @@ def get_sql_cli_jinja_env(*, loader: Optional[loaders.BaseLoader] = None):
46
48
 
47
49
  def snowflake_sql_jinja_render(content: str, data: Dict | None = None) -> str:
48
50
  data = data or {}
49
- if CONTEXT_KEY in data:
50
- raise ClickException(
51
- f"{CONTEXT_KEY} in user defined data. The `{CONTEXT_KEY}` variable is reserved for CLI usage."
52
- )
51
+
52
+ for reserved_key in RESERVED_KEYS:
53
+ if reserved_key in data:
54
+ raise ClickException(
55
+ f"{reserved_key} in user defined data. The `{reserved_key}` variable is reserved for CLI usage."
56
+ )
53
57
 
54
58
  context_data = cli_context.template_context
55
59
  context_data.update(data)
@@ -25,13 +25,15 @@ from snowflake.cli.api.project.schemas.project_definition import (
25
25
  ProjectProperties,
26
26
  build_project_definition,
27
27
  )
28
+ from snowflake.cli.api.project.schemas.updatable_model import context
28
29
  from snowflake.cli.api.rendering.jinja import CONTEXT_KEY
29
30
  from snowflake.cli.api.rendering.project_definition_templates import (
30
31
  get_project_definition_cli_jinja_env,
31
32
  )
32
- from snowflake.cli.api.utils.dict_utils import traverse
33
+ from snowflake.cli.api.utils.dict_utils import deep_merge_dicts, traverse
33
34
  from snowflake.cli.api.utils.graph import Graph, Node
34
35
  from snowflake.cli.api.utils.models import ProjectEnvironment
36
+ from snowflake.cli.api.utils.templating_functions import get_templating_functions
35
37
  from snowflake.cli.api.utils.types import Context, Definition
36
38
 
37
39
 
@@ -81,7 +83,17 @@ class TemplatedEnvironment:
81
83
  all_referenced_vars.add(TemplateVar(current_attr_chain))
82
84
  current_attr_chain = None
83
85
  elif (
84
- not isinstance(ast_node, (nodes.Template, nodes.TemplateData, nodes.Output))
86
+ not isinstance(
87
+ ast_node,
88
+ (
89
+ nodes.Template,
90
+ nodes.TemplateData,
91
+ nodes.Output,
92
+ nodes.Call,
93
+ nodes.Const,
94
+ nodes.Filter,
95
+ ),
96
+ )
85
97
  or current_attr_chain is not None
86
98
  ):
87
99
  raise InvalidTemplate(f"Unexpected templating syntax in {template_value}")
@@ -199,7 +211,6 @@ def _build_dependency_graph(
199
211
  dependencies_graph = Graph[TemplateVar]()
200
212
  for variable in all_vars:
201
213
  dependencies_graph.add(Node[TemplateVar](key=variable.key, data=variable))
202
-
203
214
  for variable in all_vars:
204
215
  # If variable is found in os.environ or from cli override, then use the value as is
205
216
  # skip rendering by pre-setting the rendered_value attribute
@@ -262,6 +273,22 @@ def _template_version_warning():
262
273
  )
263
274
 
264
275
 
276
+ def _add_defaults_to_definition(original_definition: Definition) -> Definition:
277
+ with context({"skip_validation_on_templates": True}):
278
+ # pass a flag to Pydantic to skip validation for templated scalars
279
+ # populate the defaults
280
+ project_definition = build_project_definition(**original_definition)
281
+
282
+ definition_with_defaults = project_definition.model_dump(
283
+ exclude_none=True, warnings=False, by_alias=True
284
+ )
285
+ # The main purpose of the above operation was to populate defaults from Pydantic.
286
+ # By merging the original definition back in, we ensure that any transformations
287
+ # that Pydantic would have performed are undone.
288
+ deep_merge_dicts(definition_with_defaults, original_definition)
289
+ return definition_with_defaults
290
+
291
+
265
292
  def render_definition_template(
266
293
  original_definition: Optional[Definition], context_overrides: Context
267
294
  ) -> ProjectProperties:
@@ -276,11 +303,14 @@ def render_definition_template(
276
303
  Environment variables take precedence during the rendering process.
277
304
  """
278
305
 
279
- # protect input from update
306
+ # copy input to protect it from update
280
307
  definition = copy.deepcopy(original_definition)
281
308
 
282
- # start with an environment from overrides and environment variables:
309
+ # collect all the override --env variables passed through CLI input
283
310
  override_env = context_overrides.get(CONTEXT_KEY, {}).get("env", {})
311
+
312
+ # set up Project Environment with empty default_env because
313
+ # default env section from project definition file is still templated at this time
284
314
  environment_overrides = ProjectEnvironment(
285
315
  default_env={}, override_env=override_env
286
316
  )
@@ -288,7 +318,6 @@ def render_definition_template(
288
318
  if definition is None:
289
319
  return ProjectProperties(None, {CONTEXT_KEY: {"env": environment_overrides}})
290
320
 
291
- project_context = {CONTEXT_KEY: definition}
292
321
  template_env = TemplatedEnvironment(get_project_definition_cli_jinja_env())
293
322
 
294
323
  if "definition_version" not in definition or Version(
@@ -304,12 +333,18 @@ def render_definition_template(
304
333
  # also warn on Exception, as it means the user is incorrectly attempting to use templating
305
334
  _template_version_warning()
306
335
 
307
- project_definition = build_project_definition(**original_definition)
336
+ project_definition = build_project_definition(**definition)
337
+ project_context = {CONTEXT_KEY: definition}
308
338
  project_context[CONTEXT_KEY]["env"] = environment_overrides
309
339
  return ProjectProperties(project_definition, project_context)
310
340
 
311
- default_env = definition.get("env", {})
312
- _validate_env_section(default_env)
341
+ definition = _add_defaults_to_definition(definition)
342
+ project_context = {CONTEXT_KEY: definition}
343
+
344
+ _validate_env_section(definition.get("env", {}))
345
+
346
+ # add available templating functions
347
+ project_context["fn"] = get_templating_functions()
313
348
 
314
349
  referenced_vars = _get_referenced_vars_in_definition(template_env, definition)
315
350
 
@@ -338,7 +373,11 @@ def render_definition_template(
338
373
  update_action=lambda val: template_env.render(val, final_context),
339
374
  )
340
375
 
341
- definition["env"] = ProjectEnvironment(default_env, override_env)
342
- project_context[CONTEXT_KEY] = definition
343
376
  project_definition = build_project_definition(**definition)
377
+ project_context[CONTEXT_KEY] = definition
378
+ # Use `ProjectEnvironment` in project context in order to
379
+ # handle env variables overrides from OS env and from CLI arguments.
380
+ project_context[CONTEXT_KEY]["env"] = ProjectEnvironment(
381
+ default_env=project_context[CONTEXT_KEY].get("env"), override_env=override_env
382
+ )
344
383
  return ProjectProperties(project_definition, project_context)
@@ -15,12 +15,12 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  import os
18
+ from dataclasses import dataclass
18
19
  from typing import Any, Dict, Optional
19
20
 
20
- from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel
21
21
 
22
-
23
- class ProjectEnvironment(UpdatableModel):
22
+ @dataclass
23
+ class ProjectEnvironment:
24
24
  """
25
25
  This class handles retrieval of project env variables.
26
26
  These env variables can be accessed through templating, as ctx.env.<var_name>
@@ -31,13 +31,16 @@ class ProjectEnvironment(UpdatableModel):
31
31
  - Check for default values from the project definition file.
32
32
  """
33
33
 
34
- override_env: Dict[str, Any] = {}
35
- default_env: Dict[str, Any] = {}
34
+ override_env: Dict[str, Any]
35
+ default_env: Dict[str, Any]
36
36
 
37
37
  def __init__(
38
- self, default_env: Dict[str, Any], override_env: Optional[Dict[str, Any]] = None
38
+ self,
39
+ default_env: Optional[Dict[str, Any]] = None,
40
+ override_env: Optional[Dict[str, Any]] = None,
39
41
  ):
40
- super().__init__(self, default_env=default_env, override_env=override_env or {})
42
+ self.override_env = override_env or {}
43
+ self.default_env = default_env or {}
41
44
 
42
45
  def __getitem__(self, item):
43
46
  if item in self.override_env:
@@ -0,0 +1,144 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Any, List, Optional
18
+
19
+ from snowflake.cli.api.exceptions import InvalidTemplate
20
+ from snowflake.cli.api.project.util import (
21
+ concat_identifiers,
22
+ get_env_username,
23
+ identifier_to_str,
24
+ sanitize_identifier,
25
+ to_identifier,
26
+ )
27
+
28
+
29
+ class TemplatingFunctions:
30
+ """
31
+ This class contains all the functions available for templating.
32
+ Any callable not starting with '_' will automatically be available for users to use.
33
+ """
34
+
35
+ @staticmethod
36
+ def _verify_str_arguments(
37
+ func_name: str,
38
+ args: List[Any],
39
+ *,
40
+ min_count: Optional[int] = None,
41
+ max_count: Optional[int] = None,
42
+ ):
43
+ if min_count is not None and len(args) < min_count:
44
+ raise InvalidTemplate(
45
+ f"{func_name} requires at least {min_count} argument(s)"
46
+ )
47
+
48
+ if max_count is not None and len(args) > max_count:
49
+ raise InvalidTemplate(
50
+ f"{func_name} supports at most {max_count} argument(s)"
51
+ )
52
+
53
+ for arg in args:
54
+ if not isinstance(arg, str):
55
+ raise InvalidTemplate(f"{func_name} only accepts String values")
56
+
57
+ @staticmethod
58
+ def concat_ids(*args):
59
+ """
60
+ input: one or more string arguments (SQL ID or plain String).
61
+ output: a valid SQL ID (quoted or unquoted)
62
+
63
+ Takes on multiple String arguments and concatenate them into one String.
64
+ If any of the Strings is a valid quoted ID, it will be unescaped for the concatenation process.
65
+ The resulting String is then escaped and quoted if:
66
+ - It contains non SQL safe characters
67
+ - Any of the input was a valid quoted identifier.
68
+ """
69
+ TemplatingFunctions._verify_str_arguments("concat_ids", args, min_count=1)
70
+ return concat_identifiers(args)
71
+
72
+ @staticmethod
73
+ def str_to_id(*args):
74
+ """
75
+ input: one string argument. (SQL ID or plain String)
76
+ output: a valid SQL ID (quoted or unquoted)
77
+
78
+ If the input is a valid quoted or valid unquoted identifier, return it as is.
79
+ Otherwise, if the input contains unsafe characters and is not properly quoted,
80
+ then escape it and quote it.
81
+ """
82
+ TemplatingFunctions._verify_str_arguments(
83
+ "str_to_id", args, min_count=1, max_count=1
84
+ )
85
+ return to_identifier(args[0])
86
+
87
+ @staticmethod
88
+ def id_to_str(*args):
89
+ """
90
+ input: one string argument (SQL ID or plain String).
91
+ output: a plain string
92
+
93
+ If the input is a valid SQL ID, then unescape it and return the plain String version.
94
+ Otherwise, return the input as is.
95
+ """
96
+ TemplatingFunctions._verify_str_arguments(
97
+ "id_to_str", args, min_count=1, max_count=1
98
+ )
99
+ return identifier_to_str(args[0])
100
+
101
+ @staticmethod
102
+ def get_username(*args):
103
+ """
104
+ input: one optional string containing the fallback value
105
+ output: current username detected from the Operating System
106
+
107
+ If the current username is not found or is blank, return blank
108
+ or use the fallback value if provided.
109
+ """
110
+ TemplatingFunctions._verify_str_arguments(
111
+ "get_username", args, min_count=0, max_count=1
112
+ )
113
+ fallback_username = args[0] if len(args) > 0 else ""
114
+ return get_env_username() or fallback_username
115
+
116
+ @staticmethod
117
+ def sanitize_id(*args):
118
+ """
119
+ input: one string argument
120
+ output: a valid non-quoted SQL ID
121
+
122
+ Removes any unsafe SQL characters from the input,
123
+ prepend it with an underscore if it does not start with a valid character,
124
+ and limit the result to 255 characters.
125
+ The result is a valid unquoted SQL ID.
126
+ """
127
+ TemplatingFunctions._verify_str_arguments(
128
+ "sanitize_id", args, min_count=1, max_count=1
129
+ )
130
+
131
+ return sanitize_identifier(args[0])
132
+
133
+
134
+ def get_templating_functions():
135
+ """
136
+ Returns a dictionary with all the functions available for templating
137
+ """
138
+ templating_functions = {
139
+ func: getattr(TemplatingFunctions, func)
140
+ for func in dir(TemplatingFunctions)
141
+ if callable(getattr(TemplatingFunctions, func)) and not func.startswith("_")
142
+ }
143
+
144
+ return templating_functions
@@ -0,0 +1,8 @@
1
+ set -e
2
+ export SF_REGISTRY="$(snow spcs image-registry url -c integration)"
3
+ DATABASE=$(echo "${SNOWFLAKE_CONNECTIONS_INTEGRATION_DATABASE}" | tr '[:upper:]' '[:lower:]')
4
+
5
+ echo "Using registry: ${SF_REGISTRY}"
6
+ docker build --platform linux/amd64 -t "${SF_REGISTRY}/${DATABASE}/public/snowcli_repository/test_counter" .
7
+ snow spcs image-registry token --format=json -c integration | docker login "${SF_REGISTRY}/${DATABASE}/public/snowcli_repository" -u 0sessiontoken --password-stdin
8
+ docker push "${SF_REGISTRY}/${DATABASE}/public/snowcli_repository/test_counter"
@@ -97,7 +97,7 @@ def connect_to_snowflake(
97
97
  k: v for k, v in connection_parameters.items() if v is not None
98
98
  }
99
99
 
100
- connection_parameters = _update_connection_details_with_private_key(
100
+ connection_parameters = update_connection_details_with_private_key(
101
101
  connection_parameters
102
102
  )
103
103
 
@@ -163,7 +163,7 @@ def _raise_errors_related_to_session_token(
163
163
  )
164
164
 
165
165
 
166
- def _update_connection_details_with_private_key(connection_parameters: Dict):
166
+ def update_connection_details_with_private_key(connection_parameters: Dict):
167
167
  if "private_key_path" in connection_parameters:
168
168
  if connection_parameters.get("authenticator") == "SNOWFLAKE_JWT":
169
169
  private_key = _load_pem_to_der(connection_parameters["private_key_path"])
@@ -189,13 +189,6 @@ def _load_pem_to_der(private_key_path: str) -> bytes:
189
189
  Given a private key file path (in PEM format), decode key data into DER
190
190
  format
191
191
  """
192
- from cryptography.hazmat.backends import default_backend
193
- from cryptography.hazmat.primitives.serialization import (
194
- Encoding,
195
- NoEncryption,
196
- PrivateFormat,
197
- load_pem_private_key,
198
- )
199
192
 
200
193
  with SecurePath(private_key_path).open(
201
194
  "rb", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB
@@ -222,6 +215,18 @@ def _load_pem_to_der(private_key_path: str) -> bytes:
222
215
  if private_key_pem.startswith(UNENCRYPTED_PKCS8_PK_HEADER):
223
216
  private_key_passphrase = None
224
217
 
218
+ return prepare_private_key(private_key_pem, private_key_passphrase)
219
+
220
+
221
+ def prepare_private_key(private_key_pem, private_key_passphrase=None):
222
+ from cryptography.hazmat.backends import default_backend
223
+ from cryptography.hazmat.primitives.serialization import (
224
+ Encoding,
225
+ NoEncryption,
226
+ PrivateFormat,
227
+ load_pem_private_key,
228
+ )
229
+
225
230
  private_key = load_pem_private_key(
226
231
  private_key_pem,
227
232
  (
@@ -231,7 +236,6 @@ def _load_pem_to_der(private_key_path: str) -> bytes:
231
236
  ),
232
237
  default_backend(),
233
238
  )
234
-
235
239
  return private_key.private_bytes(
236
240
  encoding=Encoding.DER,
237
241
  format=PrivateFormat.PKCS8,
@@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional
20
20
  import typer
21
21
  import yaml
22
22
  from click import ClickException
23
+ from snowflake.cli.__about__ import VERSION
23
24
  from snowflake.cli.api.commands.flags import (
24
25
  NoInteractiveOption,
25
26
  parse_key_value_variables,
@@ -67,7 +68,8 @@ TemplateOption = typer.Option(
67
68
  show_default=False,
68
69
  )
69
70
  SourceOption = typer.Option(
70
- default=DEFAULT_SOURCE,
71
+ DEFAULT_SOURCE,
72
+ "--template-source",
71
73
  help=f"local path to template directory or URL to git repository with templates.",
72
74
  )
73
75
  VariablesOption = variables_option(
@@ -131,13 +133,13 @@ def _fetch_remote_template(
131
133
  return template_root
132
134
 
133
135
 
134
- def _read_template_metadata(template_root: SecurePath) -> Template:
136
+ def _read_template_metadata(template_root: SecurePath, args_error_msg: str) -> Template:
135
137
  """Parse template.yml file."""
136
138
  template_metadata_path = template_root / TEMPLATE_METADATA_FILE_NAME
137
139
  log.debug("Reading template metadata from %s", template_metadata_path.path)
138
140
  if not template_metadata_path.exists():
139
141
  raise InvalidTemplate(
140
- f"Template does not have {TEMPLATE_METADATA_FILE_NAME} file."
142
+ f"File {TEMPLATE_METADATA_FILE_NAME} not found. {args_error_msg}"
141
143
  )
142
144
  with template_metadata_path.open(read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB) as fd:
143
145
  yaml_contents = yaml.safe_load(fd) or {}
@@ -176,7 +178,6 @@ def _determine_variable_values(
176
178
 
177
179
  def _validate_cli_version(required_version: str) -> None:
178
180
  from packaging.version import parse
179
- from snowflake.cli.__about__ import VERSION
180
181
 
181
182
  if parse(required_version) > parse(VERSION):
182
183
  raise ClickException(
@@ -203,6 +204,7 @@ def init(
203
204
  is_remote = any(
204
205
  template_source.startswith(prefix) for prefix in ["git@", "http://", "https://"] # type: ignore
205
206
  )
207
+ args_error_msg = f"Check whether {TemplateOption.param_decls[0]} and {SourceOption.param_decls[0]} arguments are correct."
206
208
 
207
209
  # copy/download template into tmpdir, so it is going to be removed in case command ends with an error
208
210
  with SecurePath.temporary_directory() as tmpdir:
@@ -217,7 +219,9 @@ def init(
217
219
  destination=tmpdir,
218
220
  )
219
221
 
220
- template_metadata = _read_template_metadata(template_root)
222
+ template_metadata = _read_template_metadata(
223
+ template_root, args_error_msg=args_error_msg
224
+ )
221
225
  if template_metadata.minimum_cli_version:
222
226
  _validate_cli_version(template_metadata.minimum_cli_version)
223
227
 
@@ -225,8 +229,10 @@ def init(
225
229
  variables_metadata=template_metadata.variables,
226
230
  variables_from_flags=variables_from_flags,
227
231
  no_interactive=no_interactive,
228
- )
229
- variable_values["project_dir_name"] = SecurePath(path).name
232
+ ) | {
233
+ "project_dir_name": SecurePath(path).name,
234
+ "snowflake_cli_version": VERSION,
235
+ }
230
236
  log.debug(
231
237
  "Rendering template files: %s", ", ".join(template_metadata.files_to_render)
232
238
  )