snowflake-cli 3.5.0__py3-none-any.whl → 3.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. snowflake/cli/__about__.py +13 -1
  2. snowflake/cli/_app/commands_registration/builtin_plugins.py +4 -0
  3. snowflake/cli/_app/loggers.py +2 -2
  4. snowflake/cli/_app/snow_connector.py +7 -6
  5. snowflake/cli/_app/telemetry.py +3 -15
  6. snowflake/cli/_app/version_check.py +4 -4
  7. snowflake/cli/_plugins/auth/__init__.py +11 -0
  8. snowflake/cli/_plugins/auth/keypair/__init__.py +0 -0
  9. snowflake/cli/_plugins/auth/keypair/commands.py +151 -0
  10. snowflake/cli/_plugins/auth/keypair/manager.py +331 -0
  11. snowflake/cli/_plugins/auth/keypair/plugin_spec.py +30 -0
  12. snowflake/cli/_plugins/connection/commands.py +78 -1
  13. snowflake/cli/_plugins/helpers/commands.py +25 -1
  14. snowflake/cli/_plugins/helpers/snowsl_vars_reader.py +133 -0
  15. snowflake/cli/_plugins/init/commands.py +9 -6
  16. snowflake/cli/_plugins/logs/__init__.py +0 -0
  17. snowflake/cli/_plugins/logs/commands.py +105 -0
  18. snowflake/cli/_plugins/logs/manager.py +107 -0
  19. snowflake/cli/_plugins/logs/plugin_spec.py +16 -0
  20. snowflake/cli/_plugins/logs/utils.py +60 -0
  21. snowflake/cli/_plugins/nativeapp/entities/application.py +4 -1
  22. snowflake/cli/_plugins/nativeapp/sf_sql_facade.py +33 -6
  23. snowflake/cli/_plugins/notebook/commands.py +3 -0
  24. snowflake/cli/_plugins/notebook/notebook_entity.py +16 -27
  25. snowflake/cli/_plugins/object/command_aliases.py +3 -1
  26. snowflake/cli/_plugins/object/manager.py +4 -2
  27. snowflake/cli/_plugins/project/commands.py +89 -48
  28. snowflake/cli/_plugins/project/manager.py +57 -23
  29. snowflake/cli/_plugins/project/project_entity_model.py +22 -3
  30. snowflake/cli/_plugins/snowpark/commands.py +15 -2
  31. snowflake/cli/_plugins/spcs/compute_pool/commands.py +17 -5
  32. snowflake/cli/_plugins/sql/manager.py +43 -52
  33. snowflake/cli/_plugins/sql/source_reader.py +230 -0
  34. snowflake/cli/_plugins/stage/manager.py +25 -12
  35. snowflake/cli/_plugins/streamlit/commands.py +3 -0
  36. snowflake/cli/_plugins/streamlit/manager.py +19 -15
  37. snowflake/cli/api/artifacts/upload.py +30 -34
  38. snowflake/cli/api/artifacts/utils.py +8 -6
  39. snowflake/cli/api/cli_global_context.py +7 -2
  40. snowflake/cli/api/commands/decorators.py +11 -2
  41. snowflake/cli/api/commands/flags.py +35 -4
  42. snowflake/cli/api/commands/snow_typer.py +20 -2
  43. snowflake/cli/api/config.py +5 -3
  44. snowflake/cli/api/constants.py +2 -0
  45. snowflake/cli/api/entities/utils.py +29 -16
  46. snowflake/cli/api/errno.py +1 -0
  47. snowflake/cli/api/exceptions.py +75 -27
  48. snowflake/cli/api/feature_flags.py +1 -0
  49. snowflake/cli/api/identifiers.py +2 -0
  50. snowflake/cli/api/plugins/plugin_config.py +2 -2
  51. snowflake/cli/api/project/schemas/template.py +3 -3
  52. snowflake/cli/api/rendering/project_templates.py +3 -3
  53. snowflake/cli/api/rendering/sql_templates.py +2 -2
  54. snowflake/cli/api/rest_api.py +2 -3
  55. snowflake/cli/{_app → api}/secret.py +4 -1
  56. snowflake/cli/api/secure_path.py +16 -4
  57. snowflake/cli/api/sql_execution.py +8 -4
  58. snowflake/cli/api/utils/definition_rendering.py +14 -8
  59. snowflake/cli/api/utils/templating_functions.py +4 -4
  60. {snowflake_cli-3.5.0.dist-info → snowflake_cli-3.7.0.dist-info}/METADATA +11 -11
  61. {snowflake_cli-3.5.0.dist-info → snowflake_cli-3.7.0.dist-info}/RECORD +64 -52
  62. {snowflake_cli-3.5.0.dist-info → snowflake_cli-3.7.0.dist-info}/WHEEL +0 -0
  63. {snowflake_cli-3.5.0.dist-info → snowflake_cli-3.7.0.dist-info}/entry_points.txt +0 -0
  64. {snowflake_cli-3.5.0.dist-info → snowflake_cli-3.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -28,7 +28,7 @@ from snowflake.cli.api.commands.flags import (
28
28
  from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
29
29
  from snowflake.cli.api.commands.utils import parse_key_value_variables
30
30
  from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB
31
- from snowflake.cli.api.exceptions import InvalidTemplate
31
+ from snowflake.cli.api.exceptions import InvalidTemplateError
32
32
  from snowflake.cli.api.output.types import (
33
33
  CommandResult,
34
34
  MessageResult,
@@ -138,7 +138,7 @@ def _read_template_metadata(template_root: SecurePath, args_error_msg: str) -> T
138
138
  template_metadata_path = template_root / TEMPLATE_METADATA_FILE_NAME
139
139
  log.debug("Reading template metadata from %s", template_metadata_path.path)
140
140
  if not template_metadata_path.exists():
141
- raise InvalidTemplate(
141
+ raise InvalidTemplateError(
142
142
  f"File {TEMPLATE_METADATA_FILE_NAME} not found. {args_error_msg}"
143
143
  )
144
144
  with template_metadata_path.open(read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB) as fd:
@@ -201,8 +201,9 @@ def init(
201
201
  variables_from_flags = {
202
202
  v.key: v.value for v in parse_key_value_variables(variables)
203
203
  }
204
- is_remote = any(
205
- template_source.startswith(prefix) for prefix in ["git@", "http://", "https://"] # type: ignore
204
+ is_remote = template_source is not None and any(
205
+ template_source.startswith(prefix)
206
+ for prefix in ["git@", "http://", "https://"] # type: ignore
206
207
  )
207
208
  args_error_msg = f"Check whether {TemplateOption.param_decls[0]} and {SourceOption.param_decls[0]} arguments are correct."
208
209
 
@@ -210,11 +211,13 @@ def init(
210
211
  with SecurePath.temporary_directory() as tmpdir:
211
212
  if is_remote:
212
213
  template_root = _fetch_remote_template(
213
- url=template_source, path=template, destination=tmpdir # type: ignore
214
+ url=template_source, # type: ignore
215
+ path=template,
216
+ destination=tmpdir, # type: ignore
214
217
  )
215
218
  else:
216
219
  template_root = _fetch_local_template(
217
- template_source=SecurePath(template_source),
220
+ template_source=SecurePath(template_source), # type: ignore
218
221
  path=template,
219
222
  destination=tmpdir,
220
223
  )
File without changes
@@ -0,0 +1,105 @@
1
+ import itertools
2
+ from datetime import datetime
3
+ from typing import Generator, Iterable, Optional, cast
4
+
5
+ import typer
6
+ from click import ClickException
7
+ from snowflake.cli._plugins.logs.manager import LogsManager
8
+ from snowflake.cli._plugins.logs.utils import LOG_LEVELS, LogsQueryRow
9
+ from snowflake.cli._plugins.object.commands import NameArgument, ObjectArgument
10
+ from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
11
+ from snowflake.cli.api.exceptions import CliArgumentError
12
+ from snowflake.cli.api.identifiers import FQN
13
+ from snowflake.cli.api.output.types import (
14
+ CommandResult,
15
+ MessageResult,
16
+ StreamResult,
17
+ )
18
+
19
+ app = SnowTyperFactory()
20
+
21
+
22
@app.command(name="logs", requires_connection=True)
def get_logs(
    object_type: str = ObjectArgument,
    object_name: FQN = NameArgument,
    from_: Optional[str] = typer.Option(
        None,
        "--from",
        help="The start time of the logs to retrieve. Accepts all ISO 8601 formats",
    ),
    to: Optional[str] = typer.Option(
        None,
        "--to",
        help="The end time of the logs to retrieve. Accepts all ISO 8601 formats",
    ),
    refresh_time: Optional[int] = typer.Option(
        None,
        "--refresh",
        help="If set, the logs will be streamed with the given refresh time in seconds",
    ),
    event_table: Optional[str] = typer.Option(
        None,
        "--table",
        help="The table to query for logs. If not provided, the default table will be used",
    ),
    log_level: Optional[str] = typer.Option(
        "INFO",
        "--log-level",
        help="The log level to filter by. If not provided, INFO will be used",
    ),
    **options,
):
    """
    Retrieves logs for a given object.
    """
    # NOTE: the docstring above is kept short and unchanged on purpose; it is
    # the user-facing help text of the command.
    #
    # Flow: validate the arguments, parse the optional time bounds, then
    # either poll the event table repeatedly (--refresh) or run one query.
    # Every log row is emitted as a MessageResult carrying only its message.

    if log_level and log_level.upper() not in LOG_LEVELS:
        raise CliArgumentError(
            f"Invalid log level. Please choose from {', '.join(LOG_LEVELS)}"
        )

    # Streaming never terminates on its own, so an upper time bound is
    # contradictory with --refresh.
    if refresh_time and to:
        raise ClickException(
            "You cannot set both --refresh and --to parameters. Please check the values"
        )

    from_time = get_datetime_from_string(from_, "--from") if from_ else None
    to_time = get_datetime_from_string(to, "--to") if to else None

    manager = LogsManager()

    if refresh_time:
        # stream_logs yields one list of rows per poll; flatten the batches
        # lazily so rows are printed as soon as their batch arrives.
        batches: Iterable[Iterable[LogsQueryRow]] = manager.stream_logs(
            object_type=object_type,
            object_name=object_name,
            from_time=from_time,
            refresh_time=refresh_time,
            event_table=event_table,
            log_level=log_level,
        )
        rows: Iterable[LogsQueryRow] = itertools.chain.from_iterable(batches)
    else:
        rows = manager.get_logs(
            object_type=object_type,
            object_name=object_name,
            from_time=from_time,
            to_time=to_time,
            event_table=event_table,
            log_level=log_level,
        )

    messages = (MessageResult(row.log_message) for row in rows)

    return StreamResult(cast(Generator[CommandResult, None, None], messages))
94
+
95
+
96
def get_datetime_from_string(
    date_str: str,
    name: Optional[str] = None,
) -> datetime:
    """
    Parse an ISO-formatted date/time string into a ``datetime``.

    :param date_str: text to parse with ``datetime.fromisoformat``.
    :param name: label of the option the value came from (e.g. ``--from``);
        used only in the error message. When omitted, the offending value
        itself is quoted instead of the literal text ``None``.
    :return: the parsed ``datetime``.
    :raises ClickException: if ``date_str`` is not a supported ISO format.
    """
    try:
        return datetime.fromisoformat(date_str)
    except ValueError as err:
        label = name if name is not None else date_str
        # Chain the original error so the parsing failure stays visible in
        # debug tracebacks.
        raise ClickException(
            f"Incorrect format for '{label}'. Please use one of approved ISO formats."
        ) from err
@@ -0,0 +1,107 @@
1
+ import time
2
+ from datetime import datetime
3
+ from textwrap import dedent
4
+ from typing import Iterable, List, Optional
5
+
6
+ from snowflake.cli._plugins.logs.utils import (
7
+ LogsQueryRow,
8
+ get_timestamp_query,
9
+ parse_log_levels_for_query,
10
+ sanitize_logs,
11
+ )
12
+ from snowflake.cli._plugins.object.commands import NameArgument, ObjectArgument
13
+ from snowflake.cli.api.identifiers import FQN
14
+ from snowflake.cli.api.sql_execution import SqlExecutionMixin
15
+ from snowflake.connector.cursor import SnowflakeCursor
16
+
17
+
18
class LogsManager(SqlExecutionMixin):
    """Queries an event table for the log records of a single object."""

    def stream_logs(
        self,
        refresh_time: int,
        object_type: str = ObjectArgument,
        object_name: FQN = NameArgument,
        from_time: Optional[datetime] = None,
        event_table: Optional[str] = None,
        log_level: Optional[str] = "INFO",
    ) -> Iterable[List[LogsQueryRow]]:
        """
        Poll the event table forever, yielding each batch of new log rows.

        :param refresh_time: seconds to sleep between two polls.
        :param object_type: object kind, used to build the attribute key.
        :param object_name: name of the object whose logs are requested.
        :param from_time: optional lower bound for the first poll.
        :param event_table: table to query; see ``get_raw_logs`` for the default.
        :param log_level: minimum severity to include.
        :return: a generator of non-empty lists of rows. It ends silently
            when a ``KeyboardInterrupt`` is raised inside it.
        """
        try:
            previous_end = from_time
            # The lower bound of each query is inclusive, so rows stamped
            # exactly `previous_end` come back on the next poll. Remember the
            # ones already emitted so they are not printed again.
            # NOTE(review): a row that is byte-for-byte identical to an
            # already emitted row at the boundary timestamp but arrives in a
            # later poll is dropped as well.
            emitted_at_boundary: set = set()

            while True:
                raw_logs = self.get_raw_logs(
                    object_type=object_type,
                    object_name=object_name,
                    from_time=previous_end,
                    to_time=None,
                    event_table=event_table,
                    log_level=log_level,
                ).fetchall()

                if raw_logs:
                    # sanitize_logs is a module-level helper, not a method.
                    batch = sanitize_logs(raw_logs)
                    fresh = [row for row in batch if row not in emitted_at_boundary]
                    if fresh:
                        yield fresh
                        newest = fresh[-1].timestamp
                        if newest != previous_end:
                            emitted_at_boundary = set()
                        emitted_at_boundary.update(
                            row for row in fresh if row.timestamp == newest
                        )
                        previous_end = newest
                time.sleep(refresh_time)

        except KeyboardInterrupt:
            return

    def get_logs(
        self,
        object_type: str = ObjectArgument,
        object_name: FQN = NameArgument,
        from_time: Optional[datetime] = None,
        to_time: Optional[datetime] = None,
        event_table: Optional[str] = None,
        log_level: Optional[str] = "INFO",
    ) -> Iterable[LogsQueryRow]:
        """
        Basic function to get a single batch of logs from the server
        """

        logs = self.get_raw_logs(
            object_type=object_type,
            object_name=object_name,
            from_time=from_time,
            to_time=to_time,
            event_table=event_table,
            log_level=log_level,
        )

        return sanitize_logs(logs)

    def get_raw_logs(
        self,
        object_type: str = ObjectArgument,
        object_name: FQN = NameArgument,
        from_time: Optional[datetime] = None,
        to_time: Optional[datetime] = None,
        event_table: Optional[str] = None,
        log_level: Optional[str] = "INFO",
    ) -> SnowflakeCursor:
        """
        Build and execute the log query, returning the raw cursor.

        :param object_type: object kind; interpolated into the
            ``snow.<object_type>.name`` resource attribute key.
        :param object_name: value the ``object_name`` column must equal.
        :param from_time: optional inclusive lower timestamp bound.
        :param to_time: optional inclusive upper timestamp bound.
        :param event_table: table to read; ``SNOWFLAKE.TELEMETRY.EVENTS``
            when not given.
        :param log_level: minimum severity; ``None`` is treated like an empty
            value, which the level helper maps to INFO.
        :return: cursor over (timestamp, database_name, schema_name,
            object_name, log_level, log_message) rows ordered by timestamp.
        """

        table = event_table if event_table else "SNOWFLAKE.TELEMETRY.EVENTS"
        # The signature allows None, but the level helper calls .upper() on
        # its argument; normalise to "" so a missing level means the default.
        levels_sql = parse_log_levels_for_query(log_level or "")

        # NOTE(review): table, object_type and object_name are interpolated
        # into the SQL text verbatim; they are not escaped or bound here.
        query = dedent(
            f"""
            SELECT
                timestamp,
                resource_attributes:"snow.database.name"::string as database_name,
                resource_attributes:"snow.schema.name"::string as schema_name,
                resource_attributes:"snow.{object_type}.name"::string as object_name,
                record:severity_text::string as log_level,
                value::string as log_message
            FROM {table}
            WHERE record_type = 'LOG'
            AND (record:severity_text IN ({levels_sql}) or record:severity_text is NULL )
            AND object_name = '{object_name}'
            {get_timestamp_query(from_time, to_time)}
            ORDER BY timestamp;
            """
        ).strip()

        result = self.execute_query(query)

        return result
@@ -0,0 +1,16 @@
1
+ from snowflake.cli._plugins.logs import commands
2
+ from snowflake.cli.api.plugins.command import (
3
+ SNOWCLI_ROOT_COMMAND_PATH,
4
+ CommandSpec,
5
+ CommandType,
6
+ plugin_hook_impl,
7
+ )
8
+
9
+
10
@plugin_hook_impl
def command_spec():
    # Plugin hook: describes how the `logs` command is attached to the CLI.
    # The command is registered as a single command directly under the root
    # command path, backed by the typer app created in the commands module.
    logs_typer = commands.app.create_instance()
    spec = CommandSpec(
        parent_command_path=SNOWCLI_ROOT_COMMAND_PATH,
        command_type=CommandType.SINGLE_COMMAND,
        typer_instance=logs_typer,
    )
    return spec
@@ -0,0 +1,60 @@
1
+ from datetime import datetime
2
+ from typing import List, NamedTuple, Optional, Tuple
3
+
4
+ from snowflake.cli.api.exceptions import CliArgumentError, CliSqlError
5
+ from snowflake.connector.cursor import SnowflakeCursor
6
+
7
# Severity names ordered from least to most severe; the order matters because
# consumers slice this list from a minimum level upwards.
LOG_LEVELS = ["TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL"]


class LogsQueryRow(NamedTuple):
    # One row of the logs query, in the column order the query selects.
    timestamp: datetime
    database_name: str
    schema_name: str
    object_name: str
    log_level: str
    log_message: str
20
+
21
+
22
def sanitize_logs(logs: SnowflakeCursor | List[Tuple]) -> List[LogsQueryRow]:
    """
    Convert raw query rows into ``LogsQueryRow`` tuples.

    Each row is unpacked positionally into the named tuple. A ``TypeError``
    raised while iterating or unpacking is reported as a ``CliSqlError``.
    """
    rows: List[LogsQueryRow] = []
    try:
        for raw_row in logs:
            rows.append(LogsQueryRow(*raw_row))
    except TypeError:
        raise CliSqlError(
            "Logs table has incorrect format. Please check the logs_table in your database"
        )
    return rows
29
+
30
+
31
def get_timestamp_query(from_time: Optional[datetime], to_time: Optional[datetime]):
    """
    Build the optional timestamp filter lines of the logs query.

    Returns a string of zero, one or two ``AND timestamp ...`` clauses, each
    terminated by a newline. Raises ``CliArgumentError`` when both bounds are
    given and the lower bound is later than the upper bound.
    """
    has_lower = from_time is not None
    has_upper = to_time is not None

    if has_lower and has_upper and from_time > to_time:
        raise CliArgumentError(
            "From_time cannot be later than to_time. Please check the values"
        )

    lower_clause = (
        f"AND timestamp >= TO_TIMESTAMP_LTZ('{from_time.isoformat()}')\n"
        if has_lower
        else ""
    )
    upper_clause = (
        f"AND timestamp <= TO_TIMESTAMP_LTZ('{to_time.isoformat()}')\n"
        if has_upper
        else ""
    )
    return lower_clause + upper_clause
45
+
46
+
47
def get_log_levels(log_level: str):
    """
    Return the requested level and every more severe level.

    The match is case-insensitive. An empty string selects INFO. Any other
    value that is not in ``LOG_LEVELS`` raises ``CliArgumentError``.
    """
    normalized = log_level.upper() if log_level != "" else "INFO"

    if normalized not in LOG_LEVELS:
        raise CliArgumentError(
            f"Invalid log level. Please choose from {', '.join(LOG_LEVELS)}"
        )

    first = LOG_LEVELS.index(normalized)
    return LOG_LEVELS[first:]
57
+
58
+
59
def parse_log_levels_for_query(log_level: str):
    """
    Render the selected log levels as a comma-separated list of
    single-quoted names, suitable for a SQL ``IN (...)`` list.
    """
    quoted_levels = [f"'{level}'" for level in get_log_levels(log_level)]
    return ", ".join(quoted_levels)
@@ -669,7 +669,7 @@ class ApplicationEntity(EntityBase[ApplicationEntityModel]):
669
669
  role_to_use=package.role,
670
670
  )
671
671
 
672
- return get_snowflake_facade().create_application(
672
+ create_app_result, warnings = get_snowflake_facade().create_application(
673
673
  name=self.name,
674
674
  package_name=package.name,
675
675
  install_method=install_method,
@@ -680,6 +680,9 @@ class ApplicationEntity(EntityBase[ApplicationEntityModel]):
680
680
  warehouse=self.warehouse,
681
681
  release_channel=release_channel,
682
682
  )
683
+ for warning in warnings:
684
+ self.console.warning(warning)
685
+ return create_app_result
683
686
 
684
687
  @span("update_app_object")
685
688
  def create_or_upgrade_app(
@@ -60,6 +60,7 @@ from snowflake.cli.api.errno import (
60
60
  CANNOT_DISABLE_MANDATORY_TELEMETRY,
61
61
  CANNOT_DISABLE_RELEASE_CHANNELS,
62
62
  CANNOT_MODIFY_RELEASE_CHANNEL_ACCOUNTS,
63
+ CANNOT_SET_DEBUG_MODE_WITH_MANIFEST_VERSION,
63
64
  DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED,
64
65
  DOES_NOT_EXIST_OR_NOT_AUTHORIZED,
65
66
  INSUFFICIENT_PRIVILEGES,
@@ -854,7 +855,7 @@ class SnowflakeSQLFacade:
854
855
  debug_mode: bool | None,
855
856
  should_authorize_event_sharing: bool | None,
856
857
  release_channel: str | None = None,
857
- ) -> list[tuple[str]]:
858
+ ) -> tuple[list[tuple[str]], list[str]]:
858
859
  """
859
860
  Creates a new application object using an application package,
860
861
  running the setup script of the application package
@@ -868,6 +869,7 @@ class SnowflakeSQLFacade:
868
869
  @param debug_mode: Whether to enable debug mode; None means not explicitly enabled or disabled
869
870
  @param should_authorize_event_sharing: Whether to enable event sharing; None means not explicitly enabled or disabled
870
871
  @param release_channel [Optional]: Release channel to use when creating the application
872
+ @return: a tuple containing the result of the create application query and possible warning messages
871
873
  """
872
874
  package_name = to_identifier(package_name)
873
875
  name = to_identifier(name)
@@ -875,11 +877,9 @@ class SnowflakeSQLFacade:
875
877
 
876
878
  # by default, applications are created in debug mode when possible;
877
879
  # this can be overridden in the project definition
878
- debug_mode_clause = ""
880
+ initial_debug_mode = False
879
881
  if install_method.is_dev_mode:
880
882
  initial_debug_mode = debug_mode if debug_mode is not None else True
881
- debug_mode_clause = f"debug_mode = {initial_debug_mode}"
882
-
883
883
  authorize_telemetry_clause = ""
884
884
  if should_authorize_event_sharing is not None:
885
885
  self._log.info(
@@ -903,13 +903,13 @@ class SnowflakeSQLFacade:
903
903
  from application package {package_name}
904
904
  {using_clause}
905
905
  {release_channel_clause}
906
- {debug_mode_clause}
907
906
  {authorize_telemetry_clause}
908
907
  comment = {SPECIAL_COMMENT}
909
908
  """
910
909
  )
911
910
  ),
912
911
  )
912
+
913
913
  except Exception as err:
914
914
  if isinstance(err, ProgrammingError):
915
915
  if err.errno == APPLICATION_REQUIRES_TELEMETRY_SHARING:
@@ -927,9 +927,36 @@ class SnowflakeSQLFacade:
927
927
  f"Failed to create application {name} with the following error message:\n"
928
928
  f"{err.msg}"
929
929
  ) from err
930
+
930
931
  handle_unclassified_error(err, f"Failed to create application {name}.")
931
932
 
932
- return create_cursor.fetchall()
933
+ warnings = []
934
+ try:
935
+ if initial_debug_mode:
936
+ self._sql_executor.execute_query(
937
+ dedent(
938
+ _strip_empty_lines(
939
+ f"""\
940
+ alter application {name}
941
+ set debug_mode = {initial_debug_mode}
942
+ """
943
+ )
944
+ )
945
+ )
946
+ except Exception as err:
947
+ if (
948
+ isinstance(err, ProgrammingError)
949
+ and err.errno == CANNOT_SET_DEBUG_MODE_WITH_MANIFEST_VERSION
950
+ ):
951
+ warnings.append(
952
+ "Did not apply debug mode to application because the manifest version is set to 2 or higher. Please use session debugging instead."
953
+ )
954
+ else:
955
+ warnings.append(
956
+ f"Failed to set debug mode for application {name}. {str(err)}"
957
+ )
958
+
959
+ return create_cursor.fetchall(), warnings
933
960
 
934
961
  def create_application_package(
935
962
  self,
@@ -23,6 +23,7 @@ from snowflake.cli._plugins.workspace.manager import WorkspaceManager
23
23
  from snowflake.cli.api.cli_global_context import get_cli_context
24
24
  from snowflake.cli.api.commands.decorators import with_project_definition
25
25
  from snowflake.cli.api.commands.flags import (
26
+ PruneOption,
26
27
  ReplaceOption,
27
28
  entity_argument,
28
29
  identifier_argument,
@@ -108,6 +109,7 @@ def deploy(
108
109
  help="Replace notebook object if it already exists. It only uploads new and overwrites existing files, "
109
110
  "but does not remove any files already on the stage.",
110
111
  ),
112
+ prune: bool = PruneOption(),
111
113
  **options,
112
114
  ) -> CommandResult:
113
115
  """Uploads a notebook and required files to a stage and creates a Snowflake notebook."""
@@ -132,6 +134,7 @@ def deploy(
132
134
  notebook.entity_id,
133
135
  EntityActions.DEPLOY,
134
136
  replace=replace,
137
+ prune=prune,
135
138
  )
136
139
  return MessageResult(
137
140
  f"Notebook successfully deployed and available under {notebook_url}"
@@ -4,9 +4,8 @@ from click import ClickException
4
4
  from snowflake.cli._plugins.connection.util import make_snowsight_url
5
5
  from snowflake.cli._plugins.notebook.notebook_entity_model import NotebookEntityModel
6
6
  from snowflake.cli._plugins.notebook.notebook_project_paths import NotebookProjectPaths
7
- from snowflake.cli._plugins.stage.manager import StageManager
8
7
  from snowflake.cli._plugins.workspace.context import ActionContext
9
- from snowflake.cli.api.artifacts.utils import bundle_artifacts
8
+ from snowflake.cli.api.artifacts.upload import sync_artifacts_with_stage
10
9
  from snowflake.cli.api.cli_global_context import get_cli_context
11
10
  from snowflake.cli.api.console.console import cli_console
12
11
  from snowflake.cli.api.entities.common import EntityBase
@@ -22,12 +21,15 @@ class NotebookEntity(EntityBase[NotebookEntityModel]):
22
21
  A notebook.
23
22
  """
24
23
 
24
+ @property
25
+ def _stage_path_from_model(self) -> str:
26
+ if self.model.stage_path is None:
27
+ return f"{_DEFAULT_NOTEBOOK_STAGE_NAME}/{self.fqn.name}"
28
+ return self.model.stage_path
29
+
25
30
  @functools.cached_property
26
31
  def _stage_path(self) -> StagePath:
27
- stage_path = self.model.stage_path
28
- if stage_path is None:
29
- stage_path = f"{_DEFAULT_NOTEBOOK_STAGE_NAME}/{self.fqn.name}"
30
- return StagePath.from_stage_str(stage_path)
32
+ return StagePath.from_stage_str(self._stage_path_from_model)
31
33
 
32
34
  @functools.cached_property
33
35
  def _project_paths(self):
@@ -41,26 +43,6 @@ class NotebookEntity(EntityBase[NotebookEntityModel]):
41
43
  except ProgrammingError:
42
44
  return False
43
45
 
44
- def _upload_artifacts(self):
45
- stage_fqn = self._stage_path.stage_fqn
46
- stage_manager = StageManager()
47
- cli_console.step(f"Creating stage {stage_fqn} if not exists")
48
- stage_manager.create(fqn=stage_fqn)
49
-
50
- cli_console.step("Uploading artifacts")
51
-
52
- # creating bundle map to handle glob patterns logic
53
- bundle_map = bundle_artifacts(self._project_paths, self.model.artifacts)
54
- for absolute_src, absolute_dest in bundle_map.all_mappings(
55
- absolute=True, expand_directories=True
56
- ):
57
- artifact_stage_path = self._stage_path / (
58
- absolute_dest.relative_to(self._project_paths.bundle_root).parent
59
- )
60
- stage_manager.put(
61
- local_path=absolute_src, stage_path=artifact_stage_path, overwrite=True
62
- )
63
-
64
46
  def get_create_sql(self, replace: bool) -> str:
65
47
  main_file_stage_path = self._stage_path / (
66
48
  self.model.notebook_file.absolute().relative_to(
@@ -99,6 +81,7 @@ class NotebookEntity(EntityBase[NotebookEntityModel]):
99
81
  self,
100
82
  action_ctx: ActionContext,
101
83
  replace: bool,
84
+ prune: bool,
102
85
  *args,
103
86
  **kwargs,
104
87
  ) -> str:
@@ -108,7 +91,13 @@ class NotebookEntity(EntityBase[NotebookEntityModel]):
108
91
  f"Notebook {self.fqn.name} already exists. Consider using --replace."
109
92
  )
110
93
  with cli_console.phase(f"Uploading artifacts to {self._stage_path}"):
111
- self._upload_artifacts()
94
+ sync_artifacts_with_stage(
95
+ project_paths=self._project_paths,
96
+ stage_root=self._stage_path_from_model,
97
+ prune=prune,
98
+ artifacts=self.model.artifacts,
99
+ )
100
+
112
101
  with cli_console.phase(f"Creating notebook {self.fqn}"):
113
102
  return self.action_create(replace=replace)
114
103
 
@@ -36,8 +36,10 @@ def add_object_command_aliases(
36
36
  name_argument: typer.Argument,
37
37
  like_option: Optional[typer.Option],
38
38
  scope_option: Optional[typer.Option],
39
- ommit_commands: List[str] = [],
39
+ ommit_commands: Optional[List[str]] = None,
40
40
  ):
41
+ if ommit_commands is None:
42
+ ommit_commands = list()
41
43
  if "list" not in ommit_commands:
42
44
  if not like_option:
43
45
  raise ClickException('[like_option] have to be defined for "list" command')
@@ -58,14 +58,16 @@ class ObjectManager(SqlExecutionMixin):
58
58
  object_name = _get_object_names(object_type).sf_name
59
59
  return self.execute_query(f"drop {object_name} {fqn.sql_identifier}")
60
60
 
61
- def describe(self, *, object_type: str, fqn: FQN):
61
+ def describe(self, *, object_type: str, fqn: FQN, **kwargs):
62
62
  # Image repository is the only supported object that does not have a DESCRIBE command.
63
63
  if object_type == "image-repository":
64
64
  raise ClickException(
65
65
  f"Describe is currently not supported for object of type image-repository"
66
66
  )
67
67
  object_name = _get_object_names(object_type).sf_name
68
- return self.execute_query(f"describe {object_name} {fqn.sql_identifier}")
68
+ return self.execute_query(
69
+ f"describe {object_name} {fqn.sql_identifier}", **kwargs
70
+ )
69
71
 
70
72
  def object_exists(self, *, object_type: str, fqn: FQN):
71
73
  try: