snowflake-cli-labs 2.6.0rc0__py3-none-any.whl → 2.7.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89) hide show
  1. snowflake/cli/__about__.py +1 -1
  2. snowflake/cli/api/cli_global_context.py +9 -0
  3. snowflake/cli/api/commands/decorators.py +9 -4
  4. snowflake/cli/api/commands/execution_metadata.py +40 -0
  5. snowflake/cli/api/commands/flags.py +45 -36
  6. snowflake/cli/api/commands/project_initialisation.py +4 -1
  7. snowflake/cli/api/commands/snow_typer.py +20 -9
  8. snowflake/cli/api/config.py +3 -0
  9. snowflake/cli/api/errno.py +27 -0
  10. snowflake/cli/api/feature_flags.py +1 -0
  11. snowflake/cli/api/identifiers.py +20 -3
  12. snowflake/cli/api/output/types.py +9 -0
  13. snowflake/cli/api/project/definition_manager.py +2 -2
  14. snowflake/cli/api/project/project_verification.py +23 -0
  15. snowflake/cli/api/project/schemas/entities/application_entity.py +50 -0
  16. snowflake/cli/api/project/schemas/entities/application_package_entity.py +63 -0
  17. snowflake/cli/api/project/schemas/entities/common.py +85 -0
  18. snowflake/cli/api/project/schemas/entities/entities.py +30 -0
  19. snowflake/cli/api/project/schemas/project_definition.py +114 -22
  20. snowflake/cli/api/project/schemas/streamlit/streamlit.py +5 -4
  21. snowflake/cli/api/project/schemas/template.py +77 -0
  22. snowflake/cli/{plugins/nativeapp/errno.py → api/rendering/__init__.py} +0 -2
  23. snowflake/cli/api/{utils/rendering.py → rendering/jinja.py} +3 -48
  24. snowflake/cli/api/rendering/project_definition_templates.py +39 -0
  25. snowflake/cli/api/rendering/project_templates.py +97 -0
  26. snowflake/cli/api/rendering/sql_templates.py +56 -0
  27. snowflake/cli/api/rest_api.py +84 -25
  28. snowflake/cli/api/sql_execution.py +40 -1
  29. snowflake/cli/api/utils/definition_rendering.py +8 -5
  30. snowflake/cli/app/cli_app.py +0 -2
  31. snowflake/cli/app/commands_registration/builtin_plugins.py +4 -0
  32. snowflake/cli/app/dev/docs/project_definition_docs_generator.py +2 -2
  33. snowflake/cli/app/loggers.py +10 -6
  34. snowflake/cli/app/printing.py +17 -7
  35. snowflake/cli/app/snow_connector.py +9 -1
  36. snowflake/cli/app/telemetry.py +41 -2
  37. snowflake/cli/plugins/connection/commands.py +4 -3
  38. snowflake/cli/plugins/connection/util.py +73 -18
  39. snowflake/cli/plugins/cortex/commands.py +2 -1
  40. snowflake/cli/plugins/git/commands.py +20 -4
  41. snowflake/cli/plugins/git/manager.py +44 -20
  42. snowflake/cli/plugins/init/__init__.py +13 -0
  43. snowflake/cli/plugins/init/commands.py +242 -0
  44. snowflake/cli/plugins/init/plugin_spec.py +30 -0
  45. snowflake/cli/plugins/nativeapp/codegen/artifact_processor.py +40 -0
  46. snowflake/cli/plugins/nativeapp/codegen/compiler.py +57 -27
  47. snowflake/cli/plugins/nativeapp/codegen/sandbox.py +99 -10
  48. snowflake/cli/plugins/nativeapp/codegen/setup/native_app_setup_processor.py +172 -0
  49. snowflake/cli/plugins/nativeapp/codegen/setup/setup_driver.py.source +56 -0
  50. snowflake/cli/plugins/nativeapp/codegen/snowpark/python_processor.py +21 -21
  51. snowflake/cli/plugins/nativeapp/commands.py +69 -6
  52. snowflake/cli/plugins/nativeapp/constants.py +0 -6
  53. snowflake/cli/plugins/nativeapp/exceptions.py +37 -12
  54. snowflake/cli/plugins/nativeapp/init.py +1 -1
  55. snowflake/cli/plugins/nativeapp/manager.py +114 -39
  56. snowflake/cli/plugins/nativeapp/project_model.py +8 -4
  57. snowflake/cli/plugins/nativeapp/run_processor.py +117 -102
  58. snowflake/cli/plugins/nativeapp/teardown_processor.py +7 -2
  59. snowflake/cli/plugins/nativeapp/v2_conversions/v2_to_v1_decorator.py +146 -0
  60. snowflake/cli/plugins/nativeapp/version/commands.py +19 -3
  61. snowflake/cli/plugins/nativeapp/version/version_processor.py +11 -3
  62. snowflake/cli/plugins/object/commands.py +1 -1
  63. snowflake/cli/plugins/object/manager.py +2 -15
  64. snowflake/cli/plugins/snowpark/commands.py +34 -26
  65. snowflake/cli/plugins/snowpark/common.py +88 -27
  66. snowflake/cli/plugins/snowpark/manager.py +16 -5
  67. snowflake/cli/plugins/snowpark/models.py +6 -0
  68. snowflake/cli/plugins/sql/commands.py +3 -5
  69. snowflake/cli/plugins/sql/manager.py +1 -1
  70. snowflake/cli/plugins/stage/commands.py +2 -2
  71. snowflake/cli/plugins/stage/diff.py +4 -2
  72. snowflake/cli/plugins/stage/manager.py +290 -86
  73. snowflake/cli/plugins/streamlit/commands.py +20 -6
  74. snowflake/cli/plugins/streamlit/manager.py +29 -27
  75. snowflake/cli/plugins/workspace/__init__.py +13 -0
  76. snowflake/cli/plugins/workspace/commands.py +35 -0
  77. snowflake/cli/plugins/workspace/plugin_spec.py +30 -0
  78. snowflake/cli/templates/default_snowpark/app/__init__.py +0 -13
  79. snowflake/cli/templates/default_snowpark/app/common.py +0 -15
  80. snowflake/cli/templates/default_snowpark/app/functions.py +0 -14
  81. snowflake/cli/templates/default_snowpark/app/procedures.py +0 -14
  82. snowflake/cli/templates/default_streamlit/common/hello.py +0 -15
  83. snowflake/cli/templates/default_streamlit/pages/my_page.py +0 -14
  84. snowflake/cli/templates/default_streamlit/streamlit_app.py +0 -14
  85. {snowflake_cli_labs-2.6.0rc0.dist-info → snowflake_cli_labs-2.7.0rc0.dist-info}/METADATA +7 -6
  86. {snowflake_cli_labs-2.6.0rc0.dist-info → snowflake_cli_labs-2.7.0rc0.dist-info}/RECORD +89 -69
  87. {snowflake_cli_labs-2.6.0rc0.dist-info → snowflake_cli_labs-2.7.0rc0.dist-info}/WHEEL +0 -0
  88. {snowflake_cli_labs-2.6.0rc0.dist-info → snowflake_cli_labs-2.7.0rc0.dist-info}/entry_points.txt +0 -0
  89. {snowflake_cli_labs-2.6.0rc0.dist-info → snowflake_cli_labs-2.7.0rc0.dist-info}/licenses/LICENSE +0 -0
@@ -12,14 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from __future__ import annotations
16
+
17
+ import json
15
18
  import logging
16
19
 
17
20
  from click.exceptions import ClickException
18
21
  from snowflake.connector import SnowflakeConnection
19
22
  from snowflake.connector.cursor import DictCursor
20
23
 
21
- LOCAL_DEPLOYMENT: str = "us-west-2"
22
-
23
24
  log = logging.getLogger(__name__)
24
25
 
25
26
  REGIONLESS_QUERY = """
@@ -29,11 +30,23 @@ REGIONLESS_QUERY = """
29
30
  )) where value['name'] = 'UI_SNOWSIGHT_ENABLE_REGIONLESS_REDIRECT';
30
31
  """
31
32
 
33
# Query whose single result column holds the account's allowlist, JSON-encoded.
ALLOWLIST_QUERY = "SELECT SYSTEM$ALLOWLIST()"
# Allowlist entry "type" value that marks a Snowflake deployment host.
SNOWFLAKE_DEPLOYMENT = "SNOWFLAKE_DEPLOYMENT"
# Region reported for hosts whose last dot-separated label is "local".
LOCAL_DEPLOYMENT_REGION: str = "us-west-2"
32
36
 
33
- class MissingConnectionHostError(ClickException):
37
+
38
class MissingConnectionAccountError(ClickException):
    """Raised when no account name can be derived for a Snowflake connection."""

    def __init__(self, conn: SnowflakeConnection):
        # The connection's repr() is appended so the failing connection can be
        # identified from the error message.
        super().__init__(
            "Could not determine account by system call, configured account name, "
            f"or configured host. Connection: {conn!r}"
        )
44
+
45
+
46
class MissingConnectionRegionError(ClickException):
    """Raised when the deployment region cannot be determined for a connection."""

    def __init__(self, host: str | None):
        message = (
            f"The connection host ({host}) was missing or not in "
            "the expected format "
            "(<account>.<deployment>.snowflakecomputing.com)"
        )
        super().__init__(message)
@@ -50,11 +63,60 @@ def is_regionless_redirect(conn: SnowflakeConnection) -> bool:
50
63
  *_, cursor = conn.execute_string(REGIONLESS_QUERY, cursor_class=DictCursor)
51
64
  return cursor.fetchone()["REGIONLESS"].lower() == "true"
52
65
  except:
53
- # by default, assume that
54
- log.warning("Cannot determine regionless redirect; assuming True.")
66
+ log.warning(
67
+ "Cannot determine regionless redirect; assuming True.", exc_info=True
68
+ )
55
69
  return True
56
70
 
57
71
 
72
def get_host_region(host: str) -> str | None:
    """
    Extract the three-part region identifier from a regioned host name.

    Hosts ending in ".local" map to LOCAL_DEPLOYMENT_REGION. Otherwise a host
    with exactly six dot-separated labels, e.g.
    ``<account>.x.y.z.snowflakecomputing.com``, yields ``"x.y.z"``.
    Only the label count is checked, not the domain. Any other shape yields None.
    """
    labels = host.split(".")
    if labels[-1] == "local":
        return LOCAL_DEPLOYMENT_REGION
    if len(labels) != 6:
        return None
    _account, *region_labels, _domain, _tld = labels
    return ".".join(region_labels)
84
+
85
+
86
def guess_regioned_host_from_allowlist(conn: SnowflakeConnection) -> str | None:
    """
    Use SYSTEM$ALLOWLIST to find a regioned host (<account>.x.y.z.snowflakecomputing.com)
    that corresponds to the given Snowflake connection object.

    Returns the host of the first allowlist entry whose "type" is
    SNOWFLAKE_DEPLOYMENT and whose host get_host_region can parse. Returns None
    when there is no such entry or when the lookup fails for any reason.
    """
    try:
        *_, cursor = conn.execute_string(ALLOWLIST_QUERY, cursor_class=DictCursor)
        allowlist_tuples = json.loads(cursor.fetchone()["SYSTEM$ALLOWLIST()"])
        for t in allowlist_tuples:
            if (
                t["type"] == SNOWFLAKE_DEPLOYMENT
                and get_host_region(t["host"]) is not None
            ):
                return t["host"]
    except Exception:
        # Best effort: query errors and unexpected payload shapes are logged
        # and treated as "no guess". Catching Exception (not a bare except)
        # lets KeyboardInterrupt/SystemExit propagate to the user.
        log.warning(
            "Could not call SYSTEM$ALLOWLIST; returning an empty guess.", exc_info=True
        )
    return None
103
+
104
+
105
def get_region(conn: SnowflakeConnection) -> str:
    """
    Return the region of the given connection, or raise MissingConnectionRegionError.

    The configured host is tried first. If it does not encode a region, a host
    suggested by SYSTEM$ALLOWLIST is tried next.
    """
    configured_region = get_host_region(conn.host) if conn.host else None
    if configured_region:
        return configured_region

    guessed_host = guess_regioned_host_from_allowlist(conn)
    if guessed_host:
        guessed_region = get_host_region(guessed_host)
        if guessed_region:
            return guessed_region

    raise MissingConnectionRegionError(guessed_host or conn.host)
118
+
119
+
58
120
  def get_context(conn: SnowflakeConnection) -> str:
59
121
  """
60
122
  Determines the first part of the path in a Snowsight URL.
@@ -67,14 +129,7 @@ def get_context(conn: SnowflakeConnection) -> str:
67
129
  )
68
130
  return cursor.fetchone()["SYSTEM$RETURN_CURRENT_ORG_NAME()"]
69
131
 
70
- host_parts = conn.host.split(".")
71
- if host_parts[-1] == "local":
72
- return LOCAL_DEPLOYMENT
73
-
74
- if len(host_parts) == 6:
75
- return ".".join(host_parts[1:4])
76
-
77
- raise MissingConnectionHostError(conn)
132
+ return get_region(conn)
78
133
 
79
134
 
80
135
  def get_account(conn: SnowflakeConnection) -> str:
@@ -91,11 +146,11 @@ def get_account(conn: SnowflakeConnection) -> str:
91
146
  if conn.account:
92
147
  return conn.account
93
148
 
94
- if not conn.host:
95
- raise MissingConnectionHostError(conn)
149
+ if conn.host:
150
+ host_parts = conn.host.split(".")
151
+ return host_parts[0]
96
152
 
97
- host_parts = conn.host.split(".")
98
- return host_parts[0]
153
+ raise MissingConnectionAccountError(conn)
99
154
 
100
155
 
101
156
  def get_snowsight_host(conn: SnowflakeConnection) -> str:
@@ -24,6 +24,7 @@ from click import UsageError
24
24
  from snowflake.cli.api.cli_global_context import cli_context
25
25
  from snowflake.cli.api.commands.flags import readable_file_option
26
26
  from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
27
+ from snowflake.cli.api.constants import PYTHON_3_12
27
28
  from snowflake.cli.api.output.types import (
28
29
  CollectionResult,
29
30
  CommandResult,
@@ -45,7 +46,7 @@ app = SnowTyperFactory(
45
46
  help="Provides access to Snowflake Cortex.",
46
47
  )
47
48
 
48
- SEARCH_COMMAND_ENABLED = sys.version_info < (3, 12)
49
# The Cortex search command is only enabled on interpreters older than PYTHON_3_12.
SEARCH_COMMAND_ENABLED = sys.version_info < PYTHON_3_12
49
50
 
50
51
 
51
52
  @app.command(
@@ -14,15 +14,18 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ import itertools
17
18
  import logging
19
+ from os import path
20
+ from pathlib import Path
18
21
  from typing import List, Optional
19
22
 
20
23
  import typer
21
24
  from click import ClickException
22
25
  from snowflake.cli.api.commands.flags import (
26
+ ExecuteVariablesOption,
23
27
  OnErrorOption,
24
28
  PatternOption,
25
- VariablesOption,
26
29
  identifier_argument,
27
30
  like_option,
28
31
  )
@@ -37,7 +40,6 @@ from snowflake.cli.plugins.object.command_aliases import (
37
40
  scope_option,
38
41
  )
39
42
  from snowflake.cli.plugins.object.manager import ObjectManager
40
- from snowflake.cli.plugins.stage.commands import get
41
43
  from snowflake.cli.plugins.stage.manager import OnErrorType
42
44
 
43
45
  app = SnowTyperFactory(
@@ -264,7 +266,6 @@ def copy(
264
266
  )
265
267
  )
266
268
  return get(
267
- recursive=True,
268
269
  source_path=repository_path,
269
270
  destination_path=destination_path,
270
271
  parallel=parallel,
@@ -275,7 +276,7 @@ def copy(
275
276
  def execute(
276
277
  repository_path: str = RepoPathArgument,
277
278
  on_error: OnErrorType = OnErrorOption,
278
- variables: Optional[List[str]] = VariablesOption,
279
+ variables: Optional[List[str]] = ExecuteVariablesOption,
279
280
  **options,
280
281
  ):
281
282
  """
@@ -287,3 +288,18 @@ def execute(
287
288
  stage_path=repository_path, on_error=on_error, variables=variables
288
289
  )
289
290
  return CollectionResult(results)
291
+
292
+
293
def get(source_path: str, destination_path: str, parallel: int):
    """
    Fetch ``source_path`` from a Git repository stage into ``destination_path``
    via GitManager.get_recursive and report the fetched files.

    Rows from every returned cursor are merged into one CollectionResult,
    ordered by the directory and then the file name of each row's "file" entry.
    """
    target = Path(destination_path).resolve()

    cursors = GitManager().get_recursive(
        stage_path=source_path, dest_path=target, parallel=parallel
    )
    rows = itertools.chain.from_iterable(QueryResult(c).result for c in cursors)
    # os.path.split returns (dirname, basename), i.e. the same sort key as
    # comparing the directory first and the file name second.
    return CollectionResult(sorted(rows, key=lambda row: path.split(row["file"])))
@@ -12,13 +12,50 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from __future__ import annotations
16
+
15
17
  from pathlib import Path
16
18
  from textwrap import dedent
19
+ from typing import List
17
20
 
18
- from snowflake.cli.plugins.stage.manager import StageManager, StagePathParts
21
+ from snowflake.cli.plugins.stage.manager import (
22
+ USER_STAGE_PREFIX,
23
+ StageManager,
24
+ StagePathParts,
25
+ UserStagePathParts,
26
+ )
19
27
  from snowflake.connector.cursor import SnowflakeCursor
20
28
 
21
29
 
30
class GitStagePathParts(StagePathParts):
    """
    Splits a Git repository stage path such as ``@repo/branch/main/dir`` or the
    fully qualified ``@db.schema.repo/branch/main/dir`` into:

    * ``stage``        -> ``@repo/branch/main/`` (resp. ``@db.schema.repo/branch/main/``)
    * ``stage_name``   -> ``repo/branch/main/`` (unqualified repo name, no ``@``)
    * ``directory``    -> ``dir`` (everything after the first three segments)
    * ``is_directory`` -> whether the given path ends with ``/``
    """

    def __init__(self, stage_path: str):
        self.stage = GitManager.get_stage_from_path(stage_path)
        stage_path_parts = Path(stage_path).parts
        # Keep only the unqualified repository name: drop "db.schema." and "@".
        git_repo_name = stage_path_parts[0].split(".")[-1]
        if git_repo_name.startswith("@"):
            git_repo_name = git_repo_name[1:]
        # The trailing "" element makes stage_name end with "/".
        self.stage_name = "/".join([git_repo_name, *stage_path_parts[1:3], ""])
        self.directory = "/".join(stage_path_parts[3:])
        self.is_directory = stage_path.endswith("/")

    @property
    def path(self) -> str:
        """stage_name and directory joined with exactly one "/" between them."""
        return (
            f"{self.stage_name}{self.directory}"
            if self.stage_name.endswith("/")
            else f"{self.stage_name}/{self.directory}"
        )

    def add_stage_prefix(self, file_path: str) -> str:
        """Replace the first segment of ``file_path`` with the first segment of ``stage``."""
        stage = Path(self.stage).parts[0]
        file_path_without_prefix = Path(file_path).parts[1:]
        return f"{stage}/{'/'.join(file_path_without_prefix)}"

    def get_directory_from_file_path(self, file_path: str) -> List[str]:
        """
        Return the directory segments of ``file_path`` below this path's directory.

        The first three segments (repo/branch/name) and the segments of
        ``directory`` are skipped, and the last segment (the file name) is dropped.
        """
        stage_path_length = len(Path(self.directory).parts)
        return list(Path(file_path).parts[3 + stage_path_length : -1])
+
58
+
22
59
  class GitManager(StageManager):
23
60
  def show_branches(self, repo_name: str, like: str) -> SnowflakeCursor:
24
61
  return self._execute_query(f"show git branches like '{like}' in {repo_name}")
@@ -51,22 +88,9 @@ class GitManager(StageManager):
51
88
  """
52
89
  return f"{'/'.join(Path(path).parts[0:3])}/"
53
90
 
54
- def _split_stage_path(self, stage_path: str) -> StagePathParts:
55
- """
56
- Splits Git repository path `@repo/branch/main/dir`
57
- stage -> @repo/branch/main/
58
- stage_name -> repo/branch/main/
59
- directory -> dir
60
- For Git repository with fully qualified name `@db.schema.repo/branch/main/dir`
61
- stage -> @db.schema.repo/branch/main/
62
- stage_name -> repo/branch/main/
63
- directory -> dir
64
- """
65
- stage = self.get_stage_from_path(stage_path)
66
- stage_path_parts = Path(stage_path).parts
67
- git_repo_name = stage_path_parts[0].split(".")[-1]
68
- if git_repo_name.startswith("@"):
69
- git_repo_name = git_repo_name[1:]
70
- stage_name = "/".join([git_repo_name, *stage_path_parts[1:3], ""])
71
- directory = "/".join(stage_path_parts[3:])
72
- return StagePathParts(stage, stage_name, directory)
91
+ @staticmethod
92
+ def _stage_path_part_factory(stage_path: str) -> StagePathParts:
93
+ stage_path = StageManager.get_standard_stage_prefix(stage_path)
94
+ if stage_path.startswith(USER_STAGE_PREFIX):
95
+ return UserStagePathParts(stage_path)
96
+ return GitStagePathParts(stage_path)
@@ -0,0 +1,13 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,242 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ from typing import Any, Dict, List, Optional
19
+
20
+ import typer
21
+ import yaml
22
+ from click import ClickException
23
+ from snowflake.cli.api.commands.flags import (
24
+ NoInteractiveOption,
25
+ parse_key_value_variables,
26
+ variables_option,
27
+ )
28
+ from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
29
+ from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB
30
+ from snowflake.cli.api.exceptions import InvalidTemplate
31
+ from snowflake.cli.api.output.types import (
32
+ CommandResult,
33
+ MessageResult,
34
+ )
35
+ from snowflake.cli.api.project.schemas.template import Template, TemplateVariable
36
+ from snowflake.cli.api.rendering.project_templates import render_template_files
37
+ from snowflake.cli.api.secure_path import SecurePath
38
+
39
# A SnowTyperFactory with default settings is enough here: this module defines a
# single command, so it never becomes a command group.
app = SnowTyperFactory()

# Template source used when none is given on the command line.
DEFAULT_SOURCE = "https://github.com/snowflakedb/snowflake-cli-templates"

log = logging.getLogger(__name__)
46
+
47
+
48
def _path_argument_callback(path: str) -> str:
    """Validate the project path argument: it must not exist yet."""
    if not SecurePath(path).exists():
        return path
    raise ClickException(
        f"The directory {path} already exists. Please specify a different path for the project."
    )
54
+
55
+
56
# Positional project directory; _path_argument_callback rejects existing paths.
PathArgument = typer.Argument(
    ...,
    help="Directory to be initialized with the project. This directory must not already exist",
    show_default=False,
    callback=_path_argument_callback,
)
TemplateOption = typer.Option(
    None,
    "--template",
    help="which template (subdirectory of --template-source) should be used. If not provided,"
    " whole source will be used as the template.",
    show_default=False,
)
SourceOption = typer.Option(
    default=DEFAULT_SOURCE,
    help="local path to template directory or URL to git repository with templates.",
)
VariablesOption = variables_option(
    "String in `key=value` format. Provided variables will not be prompted for."
)

# Metadata file that every template must contain at its root.
TEMPLATE_METADATA_FILE_NAME = "template.yml"
78
+
79
+
80
def _fetch_local_template(
    template_source: SecurePath, path: Optional[str], destination: SecurePath
) -> SecurePath:
    """
    Copy a local template into ``destination`` and return the copied template root.

    ``path``, when given, selects a subdirectory of ``template_source`` as the
    template; otherwise the whole ``template_source`` is used. The method calls
    ``template_source.assert_exists()`` first and raises ClickException when the
    selected template does not exist.
    """
    template_source.assert_exists()
    if path:
        template_origin = template_source / path
    else:
        template_origin = template_source

    log.info("Copying local template from %s", template_origin.path)
    if template_origin.exists():
        template_origin.copy(destination.path)
        return destination / template_origin.name

    raise ClickException(
        f"Template '{path}' cannot be found under {template_source}"
    )
96
+
97
+
98
def _fetch_remote_template(
    url: str, path: Optional[str], destination: SecurePath
) -> SecurePath:
    """
    Clone the repository at ``url`` into ``destination`` and return the template root.

    ``path``, when given, selects a subdirectory of the repository as the
    template. Otherwise the whole repository is the template and its ``.git``
    directory is removed so it is not copied along.

    Raises:
        ClickException: if git reports that the repository was not found, or
            the selected template directory does not exist.
        GitCommandError: for any other clone failure (re-raised unchanged).
    """
    import re

    from git import GitCommandError
    from git import rmtree as git_rmtree

    # TODO: during nativeapp refactor get rid of this dependency
    from snowflake.cli.plugins.nativeapp.utils import shallow_git_clone

    log.info("Downloading remote template from %s", url)
    try:
        shallow_git_clone(url, to_path=destination.path)
    except GitCommandError as err:
        # Translate git's "repository not found" failure into a readable error,
        # keeping the original git error as the explicit cause.
        if re.search("fatal: repository '.*' not found", err.stderr):
            raise ClickException(f"Repository '{url}' does not exist") from err
        raise

    if path:
        # template is a subdirectory of the repository
        template_root = destination / path
    else:
        # template is the whole repository;
        # remove the .git directory so it is not copied into the template
        template_root = destination
        git_rmtree((template_root / ".git").path)
    if not template_root.exists():
        raise ClickException(f"Template '{path}' cannot be found under {url}")

    return template_root
132
+
133
+
134
def _read_template_metadata(template_root: SecurePath) -> Template:
    """
    Parse the template metadata file (TEMPLATE_METADATA_FILE_NAME) in ``template_root``.

    Raises:
        InvalidTemplate: if the metadata file is missing, or its top-level YAML
            value is not a mapping.
    """
    template_metadata_path = template_root / TEMPLATE_METADATA_FILE_NAME
    log.debug("Reading template metadata from %s", template_metadata_path.path)
    if not template_metadata_path.exists():
        raise InvalidTemplate(
            f"Template does not have {TEMPLATE_METADATA_FILE_NAME} file."
        )
    with template_metadata_path.open(read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB) as fd:
        # An empty document loads as None and is treated as an empty mapping.
        yaml_contents = yaml.safe_load(fd) or {}
    if not isinstance(yaml_contents, dict):
        # Without this check a list/scalar document fails below with an opaque
        # "argument after ** must be a mapping" TypeError.
        raise InvalidTemplate(
            f"{TEMPLATE_METADATA_FILE_NAME} must contain a mapping at the top level."
        )
    return Template(template_root, **yaml_contents)
145
+
146
+
147
def _remove_template_metadata_file(template_root: SecurePath) -> None:
    """Delete the template metadata file from ``template_root``."""
    metadata_file = template_root / TEMPLATE_METADATA_FILE_NAME
    metadata_file.unlink()
149
+
150
+
151
def _determine_variable_values(
    variables_metadata: List[TemplateVariable],
    variables_from_flags: Dict[str, Any],
    no_interactive: bool,
) -> Dict[str, Any]:
    """
    Resolve a value for every variable declared by the template.

    A value supplied on the command line (``variables_from_flags``) is converted
    with the variable's ``python_type``. Every other variable is obtained via
    ``prompt_user_for_value(no_interactive)``; when ``no_interactive`` is True,
    missing variables are filled with their default values.
    Command-line values for names the template does not declare are ignored.
    """
    log.debug(
        "Resolving values of variables: %s",
        ", ".join(v.name for v in variables_metadata),
    )

    def _resolve(variable: TemplateVariable) -> Any:
        if variable.name in variables_from_flags:
            # NOTE(review): if python_type can be ``bool``, bool("false") is True;
            # confirm how TemplateVariable maps declared types.
            return variable.python_type(variables_from_flags[variable.name])
        return variable.prompt_user_for_value(no_interactive)

    return {variable.name: _resolve(variable) for variable in variables_metadata}
175
+
176
+
177
def _validate_cli_version(required_version: str) -> None:
    """Raise ClickException when the running CLI VERSION is older than ``required_version``."""
    from packaging.version import parse
    from snowflake.cli.__about__ import VERSION

    installed = parse(VERSION)
    required = parse(required_version)
    if required > installed:
        raise ClickException(
            f"Snowflake CLI version ({VERSION}) is too low - minimum version required"
            f" by template is {required_version}. Please upgrade before continuing."
        )
186
+
187
+
188
@app.command(no_args_is_help=True)
def init(
    path: str = PathArgument,
    template: Optional[str] = TemplateOption,
    template_source: Optional[str] = SourceOption,
    variables: Optional[List[str]] = VariablesOption,
    no_interactive: bool = NoInteractiveOption,
    **options,
) -> CommandResult:
    """
    Creates project directory from template.
    """
    variables_from_flags = {
        v.key: v.value for v in parse_key_value_variables(variables)
    }
    # Git/HTTP(S) sources are cloned; anything else is treated as a local path.
    is_remote = template_source.startswith(  # type: ignore
        ("git@", "http://", "https://")
    )

    # The template is assembled in a temporary directory, so it is removed
    # if the command ends with an error.
    with SecurePath.temporary_directory() as tmpdir:
        if is_remote:
            template_root = _fetch_remote_template(
                url=template_source, path=template, destination=tmpdir  # type: ignore
            )
        else:
            template_root = _fetch_local_template(
                template_source=SecurePath(template_source),
                path=template,
                destination=tmpdir,
            )

        template_metadata = _read_template_metadata(template_root)
        required_cli_version = template_metadata.minimum_cli_version
        if required_cli_version:
            _validate_cli_version(required_cli_version)

        variable_values = _determine_variable_values(
            variables_metadata=template_metadata.variables,
            variables_from_flags=variables_from_flags,
            no_interactive=no_interactive,
        )
        # Always provided to templates; overrides a user value of the same name.
        variable_values["project_dir_name"] = SecurePath(path).name
        log.debug(
            "Rendering template files: %s", ", ".join(template_metadata.files_to_render)
        )
        render_template_files(
            template_root=template_root,
            files_to_render=template_metadata.files_to_render,
            data=variable_values,
        )
        _remove_template_metadata_file(template_root)

        project_dir = SecurePath(path)
        project_dir.parent.mkdir(exist_ok=True, parents=True)
        template_root.copy(path)

    return MessageResult(f"Initialized the new project in {path}")
@@ -0,0 +1,30 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from snowflake.cli.api.plugins.command import (
16
+ SNOWCLI_ROOT_COMMAND_PATH,
17
+ CommandSpec,
18
+ CommandType,
19
+ plugin_hook_impl,
20
+ )
21
+ from snowflake.cli.plugins.init import commands
22
+
23
+
24
@plugin_hook_impl
def command_spec():
    """Register the init command as a single command under the CLI root command path."""
    init_typer = commands.app.create_instance()
    return CommandSpec(
        parent_command_path=SNOWCLI_ROOT_COMMAND_PATH,
        command_type=CommandType.SINGLE_COMMAND,
        typer_instance=init_typer,
    )
@@ -15,6 +15,7 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  from abc import ABC, abstractmethod
18
+ from pathlib import Path
18
19
  from typing import Optional
19
20
 
20
21
  from click import ClickException
@@ -34,6 +35,42 @@ class UnsupportedArtifactProcessorError(ClickException):
34
35
  )
35
36
 
36
37
 
38
def is_python_file_artifact(src: Path, _: Path) -> bool:
    """
    Return True when ``src`` is an existing regular file with a ``.py`` suffix.

    The second argument is ignored (presumably kept so the function matches a
    two-path predicate signature; confirm against callers).
    """
    if not src.is_file():
        return False
    return src.suffix == ".py"
41
+
42
+
43
class ProjectFileContextManager:
    """
    A context manager that encapsulates the logic required to update a project file
    in processor logic. The processor can use this manager to gain access to the contents
    of a file, and optionally provide replacement contents. If it does, the file is
    correctly modified in the deploy root directory to reflect the new contents.

    Usage::

        with ProjectFileContextManager(path) as f:
            f.edited_contents = transform(f.contents)

    Replacement contents are written only when the ``with`` body completes
    without raising, so a failing processor never leaves a partially edited file.
    """

    def __init__(self, path: Path):
        self.path = path
        # UTF-8 text read on __enter__; None until the context is entered.
        self._contents: Optional[str] = None
        # Set by the caller to request that the file be rewritten on exit.
        self.edited_contents: Optional[str] = None

    @property
    def contents(self) -> Optional[str]:
        """The file's text as read when the context was entered."""
        return self._contents

    def __enter__(self) -> "ProjectFileContextManager":
        self._contents = self.path.read_text(encoding="utf-8")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if exc_type is not None or self.edited_contents is None:
            # The with-body failed, or nothing was edited: leave the file as-is.
            # Returning None lets any exception propagate.
            return
        if self.path.is_symlink():
            # if the file is a symlink, make sure we don't overwrite the original:
            # drop the link and write a regular file in its place
            self.path.unlink()

        self.path.write_text(self.edited_contents, encoding="utf-8")
72
+
73
+
37
74
  class ArtifactProcessor(ABC):
38
75
  def __init__(
39
76
  self,
@@ -49,3 +86,6 @@ class ArtifactProcessor(ABC):
49
86
  **kwargs,
50
87
  ) -> None:
51
88
  pass
89
+
90
+ def edit_file(self, path: Path):
91
+ return ProjectFileContextManager(path)