snowflake-cli-labs 2.6.1__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86):
  1. snowflake/cli/__about__.py +1 -1
  2. snowflake/cli/api/cli_global_context.py +9 -0
  3. snowflake/cli/api/commands/decorators.py +9 -4
  4. snowflake/cli/api/commands/execution_metadata.py +40 -0
  5. snowflake/cli/api/commands/flags.py +45 -36
  6. snowflake/cli/api/commands/project_initialisation.py +5 -2
  7. snowflake/cli/api/commands/snow_typer.py +20 -9
  8. snowflake/cli/api/config.py +1 -0
  9. snowflake/cli/api/errno.py +27 -0
  10. snowflake/cli/api/feature_flags.py +5 -0
  11. snowflake/cli/api/identifiers.py +20 -3
  12. snowflake/cli/api/output/types.py +9 -0
  13. snowflake/cli/api/project/definition_manager.py +2 -2
  14. snowflake/cli/api/project/project_verification.py +23 -0
  15. snowflake/cli/api/project/schemas/entities/application_entity.py +50 -0
  16. snowflake/cli/api/project/schemas/entities/application_package_entity.py +63 -0
  17. snowflake/cli/api/project/schemas/entities/common.py +85 -0
  18. snowflake/cli/api/project/schemas/entities/entities.py +30 -0
  19. snowflake/cli/api/project/schemas/project_definition.py +114 -22
  20. snowflake/cli/api/project/schemas/streamlit/streamlit.py +5 -4
  21. snowflake/cli/api/project/schemas/template.py +77 -0
  22. snowflake/cli/{plugins/nativeapp/errno.py → api/rendering/__init__.py} +0 -2
  23. snowflake/cli/api/{utils/rendering.py → rendering/jinja.py} +3 -48
  24. snowflake/cli/api/rendering/project_definition_templates.py +39 -0
  25. snowflake/cli/api/rendering/project_templates.py +97 -0
  26. snowflake/cli/api/rendering/sql_templates.py +56 -0
  27. snowflake/cli/api/sql_execution.py +40 -1
  28. snowflake/cli/api/utils/definition_rendering.py +8 -5
  29. snowflake/cli/app/commands_registration/builtin_plugins.py +4 -0
  30. snowflake/cli/app/dev/docs/project_definition_docs_generator.py +2 -2
  31. snowflake/cli/app/loggers.py +3 -1
  32. snowflake/cli/app/printing.py +17 -7
  33. snowflake/cli/app/snow_connector.py +9 -1
  34. snowflake/cli/app/telemetry.py +41 -2
  35. snowflake/cli/plugins/connection/commands.py +13 -3
  36. snowflake/cli/plugins/connection/util.py +73 -18
  37. snowflake/cli/plugins/cortex/commands.py +2 -1
  38. snowflake/cli/plugins/git/commands.py +20 -4
  39. snowflake/cli/plugins/git/manager.py +44 -20
  40. snowflake/cli/plugins/init/__init__.py +13 -0
  41. snowflake/cli/plugins/init/commands.py +242 -0
  42. snowflake/cli/plugins/init/plugin_spec.py +30 -0
  43. snowflake/cli/plugins/nativeapp/codegen/artifact_processor.py +40 -0
  44. snowflake/cli/plugins/nativeapp/codegen/compiler.py +57 -27
  45. snowflake/cli/plugins/nativeapp/codegen/sandbox.py +99 -10
  46. snowflake/cli/plugins/nativeapp/codegen/setup/native_app_setup_processor.py +172 -0
  47. snowflake/cli/plugins/nativeapp/codegen/setup/setup_driver.py.source +56 -0
  48. snowflake/cli/plugins/nativeapp/codegen/snowpark/python_processor.py +21 -21
  49. snowflake/cli/plugins/nativeapp/commands.py +100 -6
  50. snowflake/cli/plugins/nativeapp/constants.py +0 -6
  51. snowflake/cli/plugins/nativeapp/exceptions.py +37 -12
  52. snowflake/cli/plugins/nativeapp/init.py +1 -1
  53. snowflake/cli/plugins/nativeapp/manager.py +114 -39
  54. snowflake/cli/plugins/nativeapp/project_model.py +8 -4
  55. snowflake/cli/plugins/nativeapp/run_processor.py +117 -102
  56. snowflake/cli/plugins/nativeapp/teardown_processor.py +7 -2
  57. snowflake/cli/plugins/nativeapp/v2_conversions/v2_to_v1_decorator.py +146 -0
  58. snowflake/cli/plugins/nativeapp/version/commands.py +19 -3
  59. snowflake/cli/plugins/nativeapp/version/version_processor.py +11 -3
  60. snowflake/cli/plugins/snowpark/commands.py +34 -26
  61. snowflake/cli/plugins/snowpark/common.py +88 -27
  62. snowflake/cli/plugins/snowpark/manager.py +16 -5
  63. snowflake/cli/plugins/snowpark/models.py +6 -0
  64. snowflake/cli/plugins/sql/commands.py +3 -5
  65. snowflake/cli/plugins/sql/manager.py +1 -1
  66. snowflake/cli/plugins/stage/commands.py +2 -2
  67. snowflake/cli/plugins/stage/diff.py +27 -64
  68. snowflake/cli/plugins/stage/manager.py +290 -86
  69. snowflake/cli/plugins/stage/md5.py +160 -0
  70. snowflake/cli/plugins/streamlit/commands.py +20 -6
  71. snowflake/cli/plugins/streamlit/manager.py +46 -32
  72. snowflake/cli/plugins/workspace/__init__.py +13 -0
  73. snowflake/cli/plugins/workspace/commands.py +35 -0
  74. snowflake/cli/plugins/workspace/plugin_spec.py +30 -0
  75. snowflake/cli/templates/default_snowpark/app/__init__.py +0 -13
  76. snowflake/cli/templates/default_snowpark/app/common.py +0 -15
  77. snowflake/cli/templates/default_snowpark/app/functions.py +0 -14
  78. snowflake/cli/templates/default_snowpark/app/procedures.py +0 -14
  79. snowflake/cli/templates/default_streamlit/common/hello.py +0 -15
  80. snowflake/cli/templates/default_streamlit/pages/my_page.py +0 -14
  81. snowflake/cli/templates/default_streamlit/streamlit_app.py +0 -14
  82. {snowflake_cli_labs-2.6.1.dist-info → snowflake_cli_labs-2.7.0.dist-info}/METADATA +7 -6
  83. {snowflake_cli_labs-2.6.1.dist-info → snowflake_cli_labs-2.7.0.dist-info}/RECORD +86 -65
  84. {snowflake_cli_labs-2.6.1.dist-info → snowflake_cli_labs-2.7.0.dist-info}/WHEEL +0 -0
  85. {snowflake_cli_labs-2.6.1.dist-info → snowflake_cli_labs-2.7.0.dist-info}/entry_points.txt +0 -0
  86. {snowflake_cli_labs-2.6.1.dist-info → snowflake_cli_labs-2.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,56 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Dict, Optional
18
+
19
+ from click import ClickException
20
+ from jinja2 import StrictUndefined, loaders
21
+ from snowflake.cli.api.cli_global_context import cli_context
22
+ from snowflake.cli.api.rendering.jinja import (
23
+ CONTEXT_KEY,
24
+ IgnoreAttrEnvironment,
25
+ env_bootstrap,
26
+ )
27
+
28
# Delimiters of a SQL template variable: `&{ name }`.
_SQL_TEMPLATE_START, _SQL_TEMPLATE_END = "&{", "}"
30
+
31
+
32
def get_sql_cli_jinja_env(*, loader: Optional[loaders.BaseLoader] = None):
    """Create the Jinja environment used for SQL templating in the CLI.

    Variables are written as ``&{ name }``. Both block delimiters are set to a
    long sentinel string, so ``{% ... %}``-style statement tags are not
    recognised in SQL text. ``StrictUndefined`` is configured so that unknown
    variables are not silently rendered as empty text.

    :param loader: Jinja loader to use; a plain ``BaseLoader`` when omitted.
    :returns: the environment after ``env_bootstrap`` has been applied to it.
    """
    block_sentinel = "___very___unique___block___to___disable___logic___blocks___"
    environment = IgnoreAttrEnvironment(
        loader=loader if loader else loaders.BaseLoader(),
        keep_trailing_newline=True,
        variable_start_string=_SQL_TEMPLATE_START,
        variable_end_string=_SQL_TEMPLATE_END,
        block_start_string=block_sentinel,
        block_end_string=block_sentinel,
        undefined=StrictUndefined,
    )
    return env_bootstrap(environment)
45
+
46
+
47
def snowflake_sql_jinja_render(content: str, data: Dict | None = None) -> str:
    """Render SQL ``content`` that uses ``&{ ... }`` template variables.

    The template sees the CLI template context (``cli_context.template_context``)
    overlaid with the user-supplied ``data``.

    :param content: SQL text to render.
    :param data: optional user variables; the reserved ``CONTEXT_KEY`` name is
        rejected.
    :returns: the rendered SQL text.
    :raises ClickException: if ``data`` contains ``CONTEXT_KEY``.
    """
    data = data or {}
    if CONTEXT_KEY in data:
        raise ClickException(
            f"{CONTEXT_KEY} in user defined data. The `{CONTEXT_KEY}` variable is reserved for CLI usage."
        )

    # Merge into a fresh dict instead of calling `.update()` on the object the
    # property returns: if that property hands out shared CLI state
    # (NOTE(review): confirm in cli_global_context), updating it in place would
    # leak this render's user variables into every later render.
    # `or {}` tolerates a context that was never set (presumably None).
    context_data = dict(cli_context.template_context or {})
    context_data.update(data)
    return get_sql_cli_jinja_env().from_string(content).render(**context_data)
@@ -41,12 +41,22 @@ from snowflake.connector.errors import ProgrammingError
41
41
 
42
42
  class SqlExecutionMixin:
43
43
  def __init__(self):
44
- pass
44
+ self._snowpark_session = None
45
45
 
46
46
  @property
47
47
  def _conn(self):
48
48
  return cli_context.connection
49
49
 
50
@property
def snowpark_session(self):
    """Snowpark ``Session`` built on the CLI connection, created on first access.

    The session is cached on ``self._snowpark_session``. The
    ``snowflake.snowpark`` import is deferred until a session is actually needed.
    """
    if self._snowpark_session:
        return self._snowpark_session

    from snowflake.snowpark.session import Session

    builder = Session.builder.configs({"connection": self._conn})
    self._snowpark_session = builder.create()
    return self._snowpark_session
59
+
50
60
  @cached_property
51
61
  def _log(self):
52
62
  return logging.getLogger(__name__)
@@ -107,6 +117,35 @@ class SqlExecutionMixin:
107
117
  if is_different_role:
108
118
  self._execute_query(f"use role {prev_role}")
109
119
 
120
@contextmanager
def use_warehouse(self, new_wh: str):
    """
    Switch to warehouse ``new_wh`` for the duration of the ``with`` block, then switch back.

    No ``use`` is issued when ``new_wh`` is already the current warehouse.
    If no warehouse was current beforehand, nothing is restored on exit.

    :param new_wh: warehouse to activate; expected to be a valid identifier
        already (it is not validated here).
    """
    row = self._execute_query(
        "select current_warehouse()", cursor_class=DictCursor
    ).fetchone()
    # If user has an assigned default warehouse, prev_wh will contain a value even if the warehouse is suspended.
    try:
        prev_wh = row["CURRENT_WAREHOUSE()"]
    except (KeyError, TypeError):
        # TypeError: fetchone() produced no row (None); KeyError: column missing.
        # A bare `except:` here would also swallow KeyboardInterrupt/SystemExit.
        prev_wh = None

    # new_wh is not None, and should already be a valid identifier, no additional check is performed here.
    is_different_wh = new_wh != prev_wh
    try:
        if is_different_wh:
            self._log.debug("Using warehouse: %s", new_wh)
            self.use(object_type=ObjectType.WAREHOUSE, name=new_wh)
        yield
    finally:
        # Restore only if we actually switched away from a known warehouse.
        if prev_wh and is_different_wh:
            self._log.debug("Switching back to warehouse: %s", prev_wh)
            self.use(object_type=ObjectType.WAREHOUSE, name=prev_wh)
148
+
110
149
  def create_password_secret(
111
150
  self, name: str, username: str, password: str
112
151
  ) -> SnowflakeCursor:
@@ -22,13 +22,16 @@ from packaging.version import Version
22
22
  from snowflake.cli.api.console import cli_console as cc
23
23
  from snowflake.cli.api.exceptions import CycleDetectedError, InvalidTemplate
24
24
  from snowflake.cli.api.project.schemas.project_definition import (
25
- ProjectDefinition,
26
25
  ProjectProperties,
26
+ build_project_definition,
27
+ )
28
+ from snowflake.cli.api.rendering.jinja import CONTEXT_KEY
29
+ from snowflake.cli.api.rendering.project_definition_templates import (
30
+ get_project_definition_cli_jinja_env,
27
31
  )
28
32
  from snowflake.cli.api.utils.dict_utils import traverse
29
33
  from snowflake.cli.api.utils.graph import Graph, Node
30
34
  from snowflake.cli.api.utils.models import ProjectEnvironment
31
- from snowflake.cli.api.utils.rendering import CONTEXT_KEY, get_snowflake_cli_jinja_env
32
35
  from snowflake.cli.api.utils.types import Context, Definition
33
36
 
34
37
 
@@ -286,7 +289,7 @@ def render_definition_template(
286
289
  return ProjectProperties(None, {CONTEXT_KEY: {"env": environment_overrides}})
287
290
 
288
291
  project_context = {CONTEXT_KEY: definition}
289
- template_env = TemplatedEnvironment(get_snowflake_cli_jinja_env())
292
+ template_env = TemplatedEnvironment(get_project_definition_cli_jinja_env())
290
293
 
291
294
  if "definition_version" not in definition or Version(
292
295
  definition["definition_version"]
@@ -301,7 +304,7 @@ def render_definition_template(
301
304
  # also warn on Exception, as it means the user is incorrectly attempting to use templating
302
305
  _template_version_warning()
303
306
 
304
- project_definition = ProjectDefinition(**original_definition)
307
+ project_definition = build_project_definition(**original_definition)
305
308
  project_context[CONTEXT_KEY]["env"] = environment_overrides
306
309
  return ProjectProperties(project_definition, project_context)
307
310
 
@@ -337,5 +340,5 @@ def render_definition_template(
337
340
 
338
341
  definition["env"] = ProjectEnvironment(default_env, override_env)
339
342
  project_context[CONTEXT_KEY] = definition
340
- project_definition = ProjectDefinition(**definition)
343
+ project_definition = build_project_definition(**definition)
341
344
  return ProjectProperties(project_definition, project_context)
@@ -15,6 +15,7 @@
15
15
  from snowflake.cli.plugins.connection import plugin_spec as connection_plugin_spec
16
16
  from snowflake.cli.plugins.cortex import plugin_spec as cortex_plugin_spec
17
17
  from snowflake.cli.plugins.git import plugin_spec as git_plugin_spec
18
+ from snowflake.cli.plugins.init import plugin_spec as init_plugin_spec
18
19
  from snowflake.cli.plugins.nativeapp import plugin_spec as nativeapp_plugin_spec
19
20
  from snowflake.cli.plugins.notebook import plugin_spec as notebook_plugin_spec
20
21
  from snowflake.cli.plugins.object import plugin_spec as object_plugin_spec
@@ -28,6 +29,7 @@ from snowflake.cli.plugins.spcs import plugin_spec as spcs_plugin_spec
28
29
  from snowflake.cli.plugins.sql import plugin_spec as sql_plugin_spec
29
30
  from snowflake.cli.plugins.stage import plugin_spec as stage_plugin_spec
30
31
  from snowflake.cli.plugins.streamlit import plugin_spec as streamlit_plugin_spec
32
+ from snowflake.cli.plugins.workspace import plugin_spec as workspace_plugin_spec
31
33
 
32
34
 
33
35
  # plugin name to plugin spec
@@ -45,6 +47,8 @@ def get_builtin_plugin_name_to_plugin_spec():
45
47
  "notebook": notebook_plugin_spec,
46
48
  "object-stage-deprecated": object_stage_deprecated_plugin_spec,
47
49
  "cortex": cortex_plugin_spec,
50
+ "init": init_plugin_spec,
51
+ "workspace": workspace_plugin_spec,
48
52
  }
49
53
 
50
54
  return plugin_specs
@@ -18,7 +18,7 @@ import logging
18
18
  from typing import Any, Dict
19
19
 
20
20
  from pydantic.json_schema import model_json_schema
21
- from snowflake.cli.api.project.schemas.project_definition import ProjectDefinition
21
+ from snowflake.cli.api.project.schemas.project_definition import DefinitionV11
22
22
  from snowflake.cli.api.secure_path import SecurePath
23
23
  from snowflake.cli.app.dev.docs.project_definition_generate_json_schema import (
24
24
  ProjectDefinitionGenerateJsonSchema,
@@ -39,7 +39,7 @@ def generate_project_definition_docs(root: SecurePath):
39
39
 
40
40
  root.mkdir(exist_ok=True)
41
41
  list_of_sections = model_json_schema(
42
- ProjectDefinition, schema_generator=ProjectDefinitionGenerateJsonSchema
42
+ DefinitionV11, schema_generator=ProjectDefinitionGenerateJsonSchema
43
43
  )["result"]
44
44
  for section in list_of_sections:
45
45
  _render_definition_description(root, section)
@@ -183,7 +183,9 @@ def create_loggers(verbose: bool, debug: bool):
183
183
  else:
184
184
  # We need to remove handler definition - otherwise it creates file even if `save_logs` is False
185
185
  del config.handlers["file"]
186
- config.loggers["snowflake.cli"].handlers.remove("file")
186
+ for logger in config.loggers.values():
187
+ if "file" in logger.handlers:
188
+ logger.handlers.remove("file")
187
189
 
188
190
  config.loggers["snowflake.cli"].level = global_log_level
189
191
  config.loggers["snowflake"].level = global_log_level
@@ -34,6 +34,7 @@ from snowflake.cli.api.output.types import (
34
34
  MessageResult,
35
35
  MultipleResults,
36
36
  ObjectResult,
37
+ StreamResult,
37
38
  )
38
39
  from snowflake.cli.api.sanitizers import sanitize_for_terminal
39
40
 
@@ -89,7 +90,7 @@ def _print_multiple_table_results(obj: CollectionResult):
89
90
  for item in items:
90
91
  table.add_row(*[str(i) for i in item.values()])
91
92
  # Add separator between tables
92
- rich_print()
93
+ rich_print(flush=True)
93
94
 
94
95
 
95
96
  def is_structured_format(output_format):
@@ -98,12 +99,21 @@ def is_structured_format(output_format):
98
99
 
99
100
def print_structured(result: CommandResult):
    """Print ``result`` for structured, parsable output formats by emitting JSON.

    * ``MultipleResults`` are delegated to ``_stream_json``.
    * A ``StreamResult`` prints each value as its own JSON document on its own
      line rather than joining the values into a JSON array.
    * Anything else is dumped as one JSON document indented by 4.

    Except for ``StreamResult`` (where every item already ends its own line),
    output is terminated with one trailing newline.
    """
    if isinstance(result, MultipleResults):
        _stream_json(result)
    elif isinstance(result, StreamResult):
        for value in result.result:
            json.dump(value, sys.stdout, cls=CustomJSONEncoder)
            print(flush=True)
        return
    else:
        json.dump(result, sys.stdout, cls=CustomJSONEncoder, indent=4)

    # Adds empty line at the end
    print(flush=True)
107
117
 
108
118
 
109
119
  def _stream_json(result):
@@ -130,11 +140,11 @@ def _stream_json(result):
130
140
  def print_unstructured(obj: CommandResult | None):
131
141
  """Handles outputs like table, plain text and other unstructured types."""
132
142
  if not obj:
133
- rich_print("Done")
143
+ rich_print("Done", flush=True)
134
144
  elif not obj.result:
135
- rich_print("No data")
145
+ rich_print("No data", flush=True)
136
146
  elif isinstance(obj, MessageResult):
137
- rich_print(sanitize_for_terminal(obj.message))
147
+ rich_print(sanitize_for_terminal(obj.message), flush=True)
138
148
  else:
139
149
  if isinstance(obj, ObjectResult):
140
150
  _print_single_table(obj)
@@ -152,14 +162,14 @@ def _print_single_table(obj):
152
162
  table.add_row(
153
163
  sanitize_for_terminal(str(key)), sanitize_for_terminal(str(value))
154
164
  )
155
- rich_print(table)
165
+ rich_print(table, flush=True)
156
166
 
157
167
 
158
168
  def print_result(cmd_result: CommandResult, output_format: OutputFormat | None = None):
159
169
  output_format = output_format or _get_format_type()
160
170
  if is_structured_format(output_format):
161
171
  print_structured(cmd_result)
162
- elif isinstance(cmd_result, MultipleResults):
172
+ elif isinstance(cmd_result, (MultipleResults, StreamResult)):
163
173
  for res in cmd_result.result:
164
174
  print_result(res)
165
175
  elif (
@@ -21,7 +21,12 @@ from typing import Dict, Optional
21
21
 
22
22
  import snowflake.connector
23
23
  from click.exceptions import ClickException
24
- from snowflake.cli.api.config import get_connection_dict, get_default_connection_dict
24
+ from snowflake.cli.api.cli_global_context import cli_context
25
+ from snowflake.cli.api.config import (
26
+ get_connection_dict,
27
+ get_default_connection_dict,
28
+ get_default_connection_name,
29
+ )
25
30
  from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB
26
31
  from snowflake.cli.api.exceptions import (
27
32
  InvalidConnectionConfiguration,
@@ -70,6 +75,9 @@ def connect_to_snowflake(
70
75
  connection_parameters = {} # we will apply overrides in next step
71
76
  else:
72
77
  connection_parameters = get_default_connection_dict()
78
+ cli_context.connection_context.set_connection_name(
79
+ get_default_connection_name()
80
+ )
73
81
 
74
82
  # Apply overrides to connection details
75
83
  for key, value in overrides.items():
@@ -22,6 +22,7 @@ from typing import Any, Dict, Union
22
22
  import click
23
23
  from snowflake.cli.__about__ import VERSION
24
24
  from snowflake.cli.api.cli_global_context import cli_context
25
+ from snowflake.cli.api.commands.execution_metadata import ExecutionMetadata
25
26
  from snowflake.cli.api.config import get_feature_flags_section
26
27
  from snowflake.cli.api.output.formats import OutputFormat
27
28
  from snowflake.cli.api.utils.error_handling import ignore_exceptions
@@ -44,19 +45,25 @@ class CLITelemetryField(Enum):
44
45
  COMMAND = "command"
45
46
  COMMAND_GROUP = "command_group"
46
47
  COMMAND_FLAGS = "command_flags"
48
+ COMMAND_EXECUTION_ID = "command_execution_id"
49
+ COMMAND_RESULT_STATUS = "command_result_status"
47
50
  COMMAND_OUTPUT_TYPE = "command_output_type"
51
+ COMMAND_EXECUTION_TIME = "command_execution_time"
48
52
  # Configuration
49
53
  CONFIG_FEATURE_FLAGS = "config_feature_flags"
50
54
  # Information
51
55
  EVENT = "event"
52
56
  ERROR_MSG = "error_msg"
53
57
  ERROR_TYPE = "error_type"
58
+ IS_CLI_EXCEPTION = "is_cli_exception"
54
59
  # Project context
55
60
  PROJECT_DEFINITION_VERSION = "project_definition_version"
56
61
 
57
62
 
58
63
class TelemetryEvent(Enum):
    """Values sent under ``TelemetryField.KEY_TYPE`` to identify a telemetry event."""

    # Emitted when a command starts executing.
    CMD_EXECUTION = "executing_command"
    # Emitted when a command fails with an exception.
    CMD_EXECUTION_ERROR = "error_executing_command"
    # Emitted with the command's final status and duration.
    CMD_EXECUTION_RESULT = "result_executing_command"
60
67
 
61
68
 
62
69
  TelemetryDict = Dict[Union[CLITelemetryField, TelemetryField], Any]
@@ -141,8 +148,40 @@ _telemetry = CLITelemetryClient(ctx=cli_context)
141
148
 
142
149
 
143
150
@ignore_exceptions()
def log_command_usage(execution: ExecutionMetadata):
    """Send the "executing_command" telemetry event tagged with the execution id."""
    payload = {
        TelemetryField.KEY_TYPE: TelemetryEvent.CMD_EXECUTION.value,
        CLITelemetryField.COMMAND_EXECUTION_ID: execution.execution_id,
    }
    _telemetry.send(payload)
158
+
159
+
160
@ignore_exceptions()
def log_command_result(execution: ExecutionMetadata):
    """Send the "result_executing_command" telemetry event for ``execution``.

    The event carries the execution id, the result status value and the value
    returned by ``execution.get_duration()``.
    """
    event = {}
    event[TelemetryField.KEY_TYPE] = TelemetryEvent.CMD_EXECUTION_RESULT.value
    event[CLITelemetryField.COMMAND_EXECUTION_ID] = execution.execution_id
    event[CLITelemetryField.COMMAND_RESULT_STATUS] = execution.status.value
    event[CLITelemetryField.COMMAND_EXECUTION_TIME] = execution.get_duration()
    _telemetry.send(event)
170
+
171
+
172
@ignore_exceptions()
def log_command_execution_error(exception: Exception, execution: ExecutionMetadata):
    """Send the "error_executing_command" telemetry event for a failed command.

    Reports the exception's class name, whether it is a ``click.ClickException``
    (or subclass), and ``execution.get_duration()``.
    """
    error_type = type(exception).__name__
    raised_as_click_error = isinstance(exception, click.ClickException)
    _telemetry.send(
        {
            TelemetryField.KEY_TYPE: TelemetryEvent.CMD_EXECUTION_ERROR.value,
            CLITelemetryField.COMMAND_EXECUTION_ID: execution.execution_id,
            CLITelemetryField.ERROR_TYPE: error_type,
            CLITelemetryField.IS_CLI_EXCEPTION: raised_as_click_error,
            CLITelemetryField.COMMAND_EXECUTION_TIME: execution.get_duration(),
        }
    )
146
185
 
147
186
 
148
187
  @ignore_exceptions()
@@ -15,6 +15,7 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  import logging
18
+ import os.path
18
19
 
19
20
  import typer
20
21
  from click import ClickException, Context, Parameter # type: ignore
@@ -218,6 +219,14 @@ def add(
218
219
  prompt="Path to private key file",
219
220
  help="Path to file containing private key",
220
221
  ),
222
+ token_file_path: str = typer.Option(
223
+ EmptyInput(),
224
+ "--token-file-path",
225
+ "-t",
226
+ click_type=OptionalPrompt(),
227
+ prompt="Path to token file",
228
+ help="Path to file with an OAuth token that should be used when connecting to Snowflake",
229
+ ),
221
230
  set_as_default: bool = typer.Option(
222
231
  False,
223
232
  "--default",
@@ -245,6 +254,7 @@ def add(
245
254
  role=role,
246
255
  authenticator=authenticator,
247
256
  private_key_path=private_key_path,
257
+ token_file_path=token_file_path,
248
258
  ),
249
259
  )
250
260
  if set_as_default:
@@ -300,9 +310,9 @@ def test(
300
310
  }
301
311
 
302
312
  if conn_ctx.enable_diag:
303
- result[
304
- "Diag Report Location"
305
- ] = f"{conn_ctx.diag_log_path}/SnowflakeConnectionTestReport.txt"
313
+ result["Diag Report Location"] = os.path.join(
314
+ conn_ctx.diag_log_path, "SnowflakeConnectionTestReport.txt"
315
+ )
306
316
 
307
317
  return ObjectResult(result)
308
318
 
@@ -12,14 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from __future__ import annotations
16
+
17
+ import json
15
18
  import logging
16
19
 
17
20
  from click.exceptions import ClickException
18
21
  from snowflake.connector import SnowflakeConnection
19
22
  from snowflake.connector.cursor import DictCursor
20
23
 
21
- LOCAL_DEPLOYMENT: str = "us-west-2"
22
-
23
24
  log = logging.getLogger(__name__)
24
25
 
25
26
  REGIONLESS_QUERY = """
@@ -29,11 +30,23 @@ REGIONLESS_QUERY = """
29
30
  )) where value['name'] = 'UI_SNOWSIGHT_ENABLE_REGIONLESS_REDIRECT';
30
31
  """
31
32
 
33
# Its single "SYSTEM$ALLOWLIST()" column holds JSON; entries are read for "type" and "host".
ALLOWLIST_QUERY = "SELECT SYSTEM$ALLOWLIST()"
# Allowlist entry type whose host is inspected for a regioned host name.
SNOWFLAKE_DEPLOYMENT = "SNOWFLAKE_DEPLOYMENT"
# Region reported for hosts whose last label is "local".
LOCAL_DEPLOYMENT_REGION: str = "us-west-2"
32
36
 
33
- class MissingConnectionHostError(ClickException):
37
+
38
+ class MissingConnectionAccountError(ClickException):
34
39
  def __init__(self, conn: SnowflakeConnection):
35
40
  super().__init__(
36
- f"The connection host ({conn.host}) was missing or not in "
41
+ "Could not determine account by system call, configured account name, or configured host. Connection: "
42
+ + repr(conn)
43
+ )
44
+
45
+
46
+ class MissingConnectionRegionError(ClickException):
47
+ def __init__(self, host: str | None):
48
+ super().__init__(
49
+ f"The connection host ({host}) was missing or not in "
37
50
  "the expected format "
38
51
  "(<account>.<deployment>.snowflakecomputing.com)"
39
52
  )
@@ -50,11 +63,60 @@ def is_regionless_redirect(conn: SnowflakeConnection) -> bool:
50
63
  *_, cursor = conn.execute_string(REGIONLESS_QUERY, cursor_class=DictCursor)
51
64
  return cursor.fetchone()["REGIONLESS"].lower() == "true"
52
65
  except:
53
- # by default, assume that
54
- log.warning("Cannot determine regionless redirect; assuming True.")
66
+ log.warning(
67
+ "Cannot determine regionless redirect; assuming True.", exc_info=True
68
+ )
55
69
  return True
56
70
 
57
71
 
72
def get_host_region(host: str) -> str | None:
    """
    Extract the three-part region identifier from a Snowflake host name.

    A host whose last dot-separated label is ``local`` maps to
    ``LOCAL_DEPLOYMENT_REGION``. A host with exactly six labels, as in
    ``<account>.<x>.<y>.<z>.snowflakecomputing.com``, yields ``"<x>.<y>.<z>"``;
    note that the trailing domain labels themselves are not checked.
    Any other shape yields None.
    """
    labels = host.split(".")
    if labels[-1] == "local":
        return LOCAL_DEPLOYMENT_REGION
    if len(labels) != 6:
        return None
    return ".".join(labels[1:4])
84
+
85
+
86
def guess_regioned_host_from_allowlist(conn: SnowflakeConnection) -> str | None:
    """
    Use SYSTEM$ALLOWLIST to find a regioned host (<account>.x.y.z.snowflakecomputing.com)
    that corresponds to the given Snowflake connection object.

    Returns the first SNOWFLAKE_DEPLOYMENT entry's host from which
    ``get_host_region`` can extract a region, or None when there is no such
    entry or the lookup fails (the failure is logged as a warning).
    """
    try:
        *_, cursor = conn.execute_string(ALLOWLIST_QUERY, cursor_class=DictCursor)
        allowlist_tuples = json.loads(cursor.fetchone()["SYSTEM$ALLOWLIST()"])
        for t in allowlist_tuples:
            if t["type"] == SNOWFLAKE_DEPLOYMENT and get_host_region(t["host"]) is not None:
                return t["host"]
    except Exception:
        # Best effort: any query/parse failure only means "no guess". Catch
        # Exception rather than using a bare `except:` so Ctrl-C (KeyboardInterrupt)
        # and SystemExit still propagate.
        log.warning(
            "Could not call SYSTEM$ALLOWLIST; returning an empty guess.", exc_info=True
        )
    return None
103
+
104
+
105
def get_region(conn: SnowflakeConnection) -> str:
    """
    Determine the region of ``conn``.

    The configured host is tried first; if it does not encode a region, a
    regioned host is guessed from SYSTEM$ALLOWLIST.

    :raises MissingConnectionRegionError: if neither source yields a region.
    """
    region = get_host_region(conn.host) if conn.host else None
    if region:
        return region

    host = guess_regioned_host_from_allowlist(conn)
    region = get_host_region(host) if host else None
    if region:
        return region

    raise MissingConnectionRegionError(host or conn.host)
118
+
119
+
58
120
  def get_context(conn: SnowflakeConnection) -> str:
59
121
  """
60
122
  Determines the first part of the path in a Snowsight URL.
@@ -67,14 +129,7 @@ def get_context(conn: SnowflakeConnection) -> str:
67
129
  )
68
130
  return cursor.fetchone()["SYSTEM$RETURN_CURRENT_ORG_NAME()"]
69
131
 
70
- host_parts = conn.host.split(".")
71
- if host_parts[-1] == "local":
72
- return LOCAL_DEPLOYMENT
73
-
74
- if len(host_parts) == 6:
75
- return ".".join(host_parts[1:4])
76
-
77
- raise MissingConnectionHostError(conn)
132
+ return get_region(conn)
78
133
 
79
134
 
80
135
  def get_account(conn: SnowflakeConnection) -> str:
@@ -91,11 +146,11 @@ def get_account(conn: SnowflakeConnection) -> str:
91
146
  if conn.account:
92
147
  return conn.account
93
148
 
94
- if not conn.host:
95
- raise MissingConnectionHostError(conn)
149
+ if conn.host:
150
+ host_parts = conn.host.split(".")
151
+ return host_parts[0]
96
152
 
97
- host_parts = conn.host.split(".")
98
- return host_parts[0]
153
+ raise MissingConnectionAccountError(conn)
99
154
 
100
155
 
101
156
  def get_snowsight_host(conn: SnowflakeConnection) -> str:
@@ -24,6 +24,7 @@ from click import UsageError
24
24
  from snowflake.cli.api.cli_global_context import cli_context
25
25
  from snowflake.cli.api.commands.flags import readable_file_option
26
26
  from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
27
+ from snowflake.cli.api.constants import PYTHON_3_12
27
28
  from snowflake.cli.api.output.types import (
28
29
  CollectionResult,
29
30
  CommandResult,
@@ -45,7 +46,7 @@ app = SnowTyperFactory(
45
46
  help="Provides access to Snowflake Cortex.",
46
47
  )
47
48
 
48
- SEARCH_COMMAND_ENABLED = sys.version_info < (3, 12)
49
# The search command is only registered on interpreters older than PYTHON_3_12
# (this replaced a literal `(3, 12)` comparison).
SEARCH_COMMAND_ENABLED = sys.version_info < PYTHON_3_12
49
50
 
50
51
 
51
52
  @app.command(
@@ -14,15 +14,18 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ import itertools
17
18
  import logging
19
+ from os import path
20
+ from pathlib import Path
18
21
  from typing import List, Optional
19
22
 
20
23
  import typer
21
24
  from click import ClickException
22
25
  from snowflake.cli.api.commands.flags import (
26
+ ExecuteVariablesOption,
23
27
  OnErrorOption,
24
28
  PatternOption,
25
- VariablesOption,
26
29
  identifier_argument,
27
30
  like_option,
28
31
  )
@@ -37,7 +40,6 @@ from snowflake.cli.plugins.object.command_aliases import (
37
40
  scope_option,
38
41
  )
39
42
  from snowflake.cli.plugins.object.manager import ObjectManager
40
- from snowflake.cli.plugins.stage.commands import get
41
43
  from snowflake.cli.plugins.stage.manager import OnErrorType
42
44
 
43
45
  app = SnowTyperFactory(
@@ -264,7 +266,6 @@ def copy(
264
266
  )
265
267
  )
266
268
  return get(
267
- recursive=True,
268
269
  source_path=repository_path,
269
270
  destination_path=destination_path,
270
271
  parallel=parallel,
@@ -275,7 +276,7 @@ def copy(
275
276
  def execute(
276
277
  repository_path: str = RepoPathArgument,
277
278
  on_error: OnErrorType = OnErrorOption,
278
- variables: Optional[List[str]] = VariablesOption,
279
+ variables: Optional[List[str]] = ExecuteVariablesOption,
279
280
  **options,
280
281
  ):
281
282
  """
@@ -287,3 +288,18 @@ def execute(
287
288
  stage_path=repository_path, on_error=on_error, variables=variables
288
289
  )
289
290
  return CollectionResult(results)
291
+
292
+
293
def get(source_path: str, destination_path: str, parallel: int):
    """Copy files from ``source_path`` into ``destination_path`` via ``GitManager.get_recursive``.

    The rows of every returned cursor are flattened into one list and sorted by
    (directory, base name) of each row's ``file`` value.
    """
    target = Path(destination_path).resolve()
    cursors = GitManager().get_recursive(
        stage_path=source_path, dest_path=target, parallel=parallel
    )
    rows = [row for cursor in cursors for row in QueryResult(cursor).result]
    rows.sort(
        key=lambda entry: (path.dirname(entry["file"]), path.basename(entry["file"]))
    )
    return CollectionResult(rows)