snowflake-cli-labs 2.6.1__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. snowflake/cli/__about__.py +1 -1
  2. snowflake/cli/api/cli_global_context.py +9 -0
  3. snowflake/cli/api/commands/decorators.py +9 -4
  4. snowflake/cli/api/commands/execution_metadata.py +40 -0
  5. snowflake/cli/api/commands/flags.py +45 -36
  6. snowflake/cli/api/commands/project_initialisation.py +5 -2
  7. snowflake/cli/api/commands/snow_typer.py +20 -9
  8. snowflake/cli/api/config.py +1 -0
  9. snowflake/cli/api/errno.py +27 -0
  10. snowflake/cli/api/feature_flags.py +5 -0
  11. snowflake/cli/api/identifiers.py +20 -3
  12. snowflake/cli/api/output/types.py +9 -0
  13. snowflake/cli/api/project/definition_manager.py +2 -2
  14. snowflake/cli/api/project/project_verification.py +23 -0
  15. snowflake/cli/api/project/schemas/entities/application_entity.py +50 -0
  16. snowflake/cli/api/project/schemas/entities/application_package_entity.py +63 -0
  17. snowflake/cli/api/project/schemas/entities/common.py +85 -0
  18. snowflake/cli/api/project/schemas/entities/entities.py +30 -0
  19. snowflake/cli/api/project/schemas/project_definition.py +114 -22
  20. snowflake/cli/api/project/schemas/streamlit/streamlit.py +5 -4
  21. snowflake/cli/api/project/schemas/template.py +77 -0
  22. snowflake/cli/{plugins/nativeapp/errno.py → api/rendering/__init__.py} +0 -2
  23. snowflake/cli/api/{utils/rendering.py → rendering/jinja.py} +3 -48
  24. snowflake/cli/api/rendering/project_definition_templates.py +39 -0
  25. snowflake/cli/api/rendering/project_templates.py +97 -0
  26. snowflake/cli/api/rendering/sql_templates.py +56 -0
  27. snowflake/cli/api/sql_execution.py +40 -1
  28. snowflake/cli/api/utils/definition_rendering.py +8 -5
  29. snowflake/cli/app/commands_registration/builtin_plugins.py +4 -0
  30. snowflake/cli/app/dev/docs/project_definition_docs_generator.py +2 -2
  31. snowflake/cli/app/loggers.py +3 -1
  32. snowflake/cli/app/printing.py +17 -7
  33. snowflake/cli/app/snow_connector.py +9 -1
  34. snowflake/cli/app/telemetry.py +41 -2
  35. snowflake/cli/plugins/connection/commands.py +13 -3
  36. snowflake/cli/plugins/connection/util.py +73 -18
  37. snowflake/cli/plugins/cortex/commands.py +2 -1
  38. snowflake/cli/plugins/git/commands.py +20 -4
  39. snowflake/cli/plugins/git/manager.py +44 -20
  40. snowflake/cli/plugins/init/__init__.py +13 -0
  41. snowflake/cli/plugins/init/commands.py +242 -0
  42. snowflake/cli/plugins/init/plugin_spec.py +30 -0
  43. snowflake/cli/plugins/nativeapp/codegen/artifact_processor.py +40 -0
  44. snowflake/cli/plugins/nativeapp/codegen/compiler.py +57 -27
  45. snowflake/cli/plugins/nativeapp/codegen/sandbox.py +99 -10
  46. snowflake/cli/plugins/nativeapp/codegen/setup/native_app_setup_processor.py +172 -0
  47. snowflake/cli/plugins/nativeapp/codegen/setup/setup_driver.py.source +56 -0
  48. snowflake/cli/plugins/nativeapp/codegen/snowpark/python_processor.py +21 -21
  49. snowflake/cli/plugins/nativeapp/commands.py +100 -6
  50. snowflake/cli/plugins/nativeapp/constants.py +0 -6
  51. snowflake/cli/plugins/nativeapp/exceptions.py +37 -12
  52. snowflake/cli/plugins/nativeapp/init.py +1 -1
  53. snowflake/cli/plugins/nativeapp/manager.py +114 -39
  54. snowflake/cli/plugins/nativeapp/project_model.py +8 -4
  55. snowflake/cli/plugins/nativeapp/run_processor.py +117 -102
  56. snowflake/cli/plugins/nativeapp/teardown_processor.py +7 -2
  57. snowflake/cli/plugins/nativeapp/v2_conversions/v2_to_v1_decorator.py +146 -0
  58. snowflake/cli/plugins/nativeapp/version/commands.py +19 -3
  59. snowflake/cli/plugins/nativeapp/version/version_processor.py +11 -3
  60. snowflake/cli/plugins/snowpark/commands.py +34 -26
  61. snowflake/cli/plugins/snowpark/common.py +88 -27
  62. snowflake/cli/plugins/snowpark/manager.py +16 -5
  63. snowflake/cli/plugins/snowpark/models.py +6 -0
  64. snowflake/cli/plugins/sql/commands.py +3 -5
  65. snowflake/cli/plugins/sql/manager.py +1 -1
  66. snowflake/cli/plugins/stage/commands.py +2 -2
  67. snowflake/cli/plugins/stage/diff.py +27 -64
  68. snowflake/cli/plugins/stage/manager.py +290 -86
  69. snowflake/cli/plugins/stage/md5.py +160 -0
  70. snowflake/cli/plugins/streamlit/commands.py +20 -6
  71. snowflake/cli/plugins/streamlit/manager.py +46 -32
  72. snowflake/cli/plugins/workspace/__init__.py +13 -0
  73. snowflake/cli/plugins/workspace/commands.py +35 -0
  74. snowflake/cli/plugins/workspace/plugin_spec.py +30 -0
  75. snowflake/cli/templates/default_snowpark/app/__init__.py +0 -13
  76. snowflake/cli/templates/default_snowpark/app/common.py +0 -15
  77. snowflake/cli/templates/default_snowpark/app/functions.py +0 -14
  78. snowflake/cli/templates/default_snowpark/app/procedures.py +0 -14
  79. snowflake/cli/templates/default_streamlit/common/hello.py +0 -15
  80. snowflake/cli/templates/default_streamlit/pages/my_page.py +0 -14
  81. snowflake/cli/templates/default_streamlit/streamlit_app.py +0 -14
  82. {snowflake_cli_labs-2.6.1.dist-info → snowflake_cli_labs-2.7.0.dist-info}/METADATA +7 -6
  83. {snowflake_cli_labs-2.6.1.dist-info → snowflake_cli_labs-2.7.0.dist-info}/RECORD +86 -65
  84. {snowflake_cli_labs-2.6.1.dist-info → snowflake_cli_labs-2.7.0.dist-info}/WHEEL +0 -0
  85. {snowflake_cli_labs-2.6.1.dist-info → snowflake_cli_labs-2.7.0.dist-info}/entry_points.txt +0 -0
  86. {snowflake_cli_labs-2.6.1.dist-info → snowflake_cli_labs-2.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -18,49 +18,155 @@ import fnmatch
18
18
  import glob
19
19
  import logging
20
20
  import re
21
+ import sys
21
22
  from contextlib import nullcontext
22
23
  from dataclasses import dataclass
23
24
  from os import path
24
25
  from pathlib import Path
26
+ from textwrap import dedent
25
27
  from typing import Dict, List, Optional, Union
26
28
 
27
29
  from click import ClickException
28
- from snowflake.cli.api.commands.flags import OnErrorType, parse_key_value_variables
30
+ from snowflake.cli.api.commands.flags import (
31
+ OnErrorType,
32
+ Variable,
33
+ parse_key_value_variables,
34
+ )
29
35
  from snowflake.cli.api.console import cli_console
36
+ from snowflake.cli.api.constants import PYTHON_3_12
37
+ from snowflake.cli.api.identifiers import FQN
30
38
  from snowflake.cli.api.project.util import to_string_literal
31
39
  from snowflake.cli.api.secure_path import SecurePath
32
40
  from snowflake.cli.api.sql_execution import SqlExecutionMixin
33
41
  from snowflake.cli.api.utils.path_utils import path_resolver
42
+ from snowflake.cli.plugins.snowpark.package_utils import parse_requirements
34
43
  from snowflake.connector import DictCursor, ProgrammingError
35
44
  from snowflake.connector.cursor import SnowflakeCursor
36
45
 
46
+ if sys.version_info < PYTHON_3_12:
47
+ # Snowpark only works on Python versions below 3.12, and Session must be imported here so that @sproc can be used.
48
+ from snowflake.snowpark import Session
49
+
37
50
  log = logging.getLogger(__name__)
38
51
 
39
52
 
40
53
  UNQUOTED_FILE_URI_REGEX = r"[\w/*?\-.=&{}$#[\]\"\\!@%^+:]+"
41
- EXECUTE_SUPPORTED_FILES_FORMATS = {".sql"}
54
+ USER_STAGE_PREFIX = "@~"
55
+ EXECUTE_SUPPORTED_FILES_FORMATS = (
56
+ ".sql",
57
+ ".py",
58
+ ) # a tuple so that the order is preserved, but it is used as a set
42
59
 
43
60
 
44
61
@dataclass
class StagePathParts:
    """
    Components of a stage path, shared by the concrete stage-path classes.

    The base class only stores the fields and provides the helpers that are
    common to every flavour.  ``path``, ``add_stage_prefix`` and
    ``get_directory_from_file_path`` raise ``NotImplementedError`` here and
    are expected to be supplied by subclasses.
    """

    # Stage path without its first component (see ``get_directory``).
    directory: str
    # Stage reference as assigned by the subclass, e.g. "@db.schema.stage" or "@~".
    stage: str
    # Stage name as assigned by the subclass.
    stage_name: str
    # True when the stage path the parts were built from ended with "/".
    is_directory: bool

    @staticmethod
    def get_directory(stage_path: str) -> str:
        """Return ``stage_path`` minus its first component, joined with "/"."""
        remainder = Path(stage_path).parts[1:]
        return "/".join(remainder)

    @property
    def path(self) -> str:
        """Path of the location; must be provided by subclasses."""
        raise NotImplementedError

    def add_stage_prefix(self, file_path: str) -> str:
        """Turn ``file_path`` into a stage reference; must be provided by subclasses."""
        raise NotImplementedError

    def get_directory_from_file_path(self, file_path: str) -> List[str]:
        """Return directory components of ``file_path``; must be provided by subclasses."""
        raise NotImplementedError

    def get_full_stage_path(self, path: str):
        """
        Prepend the prefix of ``self.stage`` (as reported by
        ``FQN.from_stage(...).prefix``) to ``path``, separated by a dot.
        ``path`` is returned unchanged when that prefix is empty.
        """
        prefix = FQN.from_stage(self.stage).prefix
        if not prefix:
            return path
        return prefix + "." + path

    def get_standard_stage_path(self) -> str:
        """
        Return ``self.path`` prefixed with "@".  A trailing "/" is appended
        when the parts describe a directory and the path lacks one.
        """
        # NOTE(review): this is built from ``path`` alone.  In
        # DefaultStagePathParts ``path`` uses the unqualified ``stage_name``,
        # so a database/schema qualifier present in ``stage`` is not part of
        # the result - confirm that this is intended.
        path = self.path
        needs_slash = self.is_directory and not path.endswith("/")
        return f"@{path}{'/' if needs_slash else ''}"

    def get_standard_stage_directory_path(self) -> str:
        """Return the standard stage path, always ending with "/"."""
        path = self.get_standard_stage_path()
        return path if path.endswith("/") else path + "/"
96
+
97
+
98
@dataclass
class DefaultStagePathParts(StagePathParts):
    """
    Stage path parts for a named stage.

    For a path like ``@db.schema.stage/dir`` the values will be:
        directory = dir
        stage = @db.schema.stage
        stage_name = stage
    For ``@stage/dir``:
        directory = dir
        stage = @stage
        stage_name = stage
    """

    def __init__(self, stage_path: str):
        self.directory = self.get_directory(stage_path)
        self.stage = StageManager.get_stage_from_path(stage_path)
        # The name is the last dot-separated segment of the stage reference,
        # with a leading "@" removed when there is one.
        unqualified = self.stage.split(".")[-1]
        self.stage_name = (
            unqualified[1:] if unqualified.startswith("@") else unqualified
        )
        self.is_directory = stage_path.endswith("/")

    @property
    def path(self) -> str:
        """Return ``stage_name`` and ``directory`` joined by exactly one "/"."""
        separator = "" if self.stage_name.endswith("/") else "/"
        return f"{self.stage_name}{separator}{self.directory}"

    def add_stage_prefix(self, file_path: str) -> str:
        """
        Replace the first component of ``file_path`` with the first component
        of ``self.stage``.
        """
        stage_root = Path(self.stage).parts[0]
        relative_parts = Path(file_path).parts[1:]
        return f"{stage_root}/{'/'.join(relative_parts)}"

    def get_directory_from_file_path(self, file_path: str) -> List[str]:
        """
        Return the components of ``file_path`` that lie between the stage
        directory and the file name.  The first component, the components
        covered by ``self.directory`` and the last component are dropped.
        """
        skipped = 1 + len(Path(self.directory).parts)
        return list(Path(file_path).parts[skipped:-1])
136
+
137
+
138
@dataclass
class UserStagePathParts(StagePathParts):
    """
    Stage path parts for the user stage.

    For a path like ``@~/dir`` the values will be:
        directory = dir
        stage = @~
        stage_name = @~
    """

    def __init__(self, stage_path: str):
        self.directory = self.get_directory(stage_path)
        # Use the shared constant instead of repeating the "@~" literal.
        self.stage = USER_STAGE_PREFIX
        self.stage_name = USER_STAGE_PREFIX
        self.is_directory = stage_path.endswith("/")

    @property
    def path(self) -> str:
        """Return the directory part only; the user-stage prefix is not included."""
        return f"{self.directory}"

    def get_standard_stage_path(self) -> str:
        """
        Return the location as a user-stage reference, e.g. ``@~/dir``.

        The inherited implementation prefixes ``path`` with "@".  Because
        ``path`` here is only the directory, that produced ``@dir``, which
        names a different stage.  This override keeps the ``@~`` prefix.
        A trailing "/" is appended when the parts describe a directory.
        """
        location = f"{self.stage}/{self.directory}" if self.directory else self.stage
        if self.is_directory and not location.endswith("/"):
            location += "/"
        return location

    def add_stage_prefix(self, file_path: str) -> str:
        """Prefix ``file_path`` with the user stage, separated by "/"."""
        return f"{self.stage}/{file_path}"

    def get_directory_from_file_path(self, file_path: str) -> List[str]:
        """
        Return the components of ``file_path`` after those covered by
        ``self.directory``, excluding the last component.
        """
        stage_path_length = len(Path(self.directory).parts)
        return list(Path(file_path).parts[stage_path_length:-1])
163
+
62
164
 
63
165
  class StageManager(SqlExecutionMixin):
166
def __init__(self):
    """Set up the manager with no Python-execution procedure prepared yet."""
    super().__init__()
    # Assigned in `execute` once at least one selected file ends with ".py".
    self._python_exe_procedure = None
169
+
64
170
  @staticmethod
65
171
  def get_standard_stage_prefix(name: str) -> str:
66
172
  # Handle embedded stages
@@ -69,12 +175,6 @@ class StageManager(SqlExecutionMixin):
69
175
 
70
176
  return f"@{name}"
71
177
 
72
- @staticmethod
73
- def get_standard_stage_directory_path(path):
74
- if not path.endswith("/"):
75
- path += "/"
76
- return StageManager.get_standard_stage_prefix(path)
77
-
78
178
  @staticmethod
79
179
  def get_stage_from_path(path: str):
80
180
  """
@@ -96,12 +196,6 @@ class StageManager(SqlExecutionMixin):
96
196
 
97
197
  return standard_name
98
198
 
99
- @staticmethod
100
- def remove_stage_prefix(stage_path: str) -> str:
101
- if stage_path.startswith("@"):
102
- return stage_path[1:]
103
- return stage_path
104
-
105
199
  def _to_uri(self, local_path: str):
106
200
  uri = f"file://{local_path}"
107
201
  if re.fullmatch(UNQUOTED_FILE_URI_REGEX, uri):
@@ -135,20 +229,17 @@ class StageManager(SqlExecutionMixin):
135
229
  def get_recursive(
136
230
  self, stage_path: str, dest_path: Path, parallel: int = 4
137
231
  ) -> List[SnowflakeCursor]:
138
- stage_path = self.get_standard_stage_prefix(stage_path)
139
- stage_parts_length = len(Path(stage_path).parts)
232
+ stage_path_parts = self._stage_path_part_factory(stage_path)
140
233
 
141
234
  results = []
142
- for file in self.iter_stage(stage_path):
143
- dest_directory = dest_path / "/".join(
144
- Path(file).parts[stage_parts_length:-1]
145
- )
146
- self._assure_is_existing_directory(Path(dest_directory))
147
-
148
- stage_path_with_prefix = self.get_standard_stage_prefix(file)
235
+ for file_path in self.iter_stage(stage_path):
236
+ dest_directory = dest_path
237
+ for path_part in stage_path_parts.get_directory_from_file_path(file_path):
238
+ dest_directory = dest_directory / path_part
239
+ self._assure_is_existing_directory(dest_directory)
149
240
 
150
241
  result = self._execute_query(
151
- f"get {self.quote_stage_name(stage_path_with_prefix)} {self._to_uri(f'{dest_directory}/')} parallel={parallel}"
242
+ f"get {self.quote_stage_name(stage_path_parts.add_stage_prefix(file_path))} {self._to_uri(f'{dest_directory}/')} parallel={parallel}"
152
243
  )
153
244
  results.append(result)
154
245
 
@@ -180,8 +271,16 @@ class StageManager(SqlExecutionMixin):
180
271
  return cursor
181
272
 
182
273
def copy_files(self, source_path: str, destination_path: str) -> SnowflakeCursor:
    """
    Copy files between stage locations with a ``copy files into`` query.

    Both paths are parsed with ``_stage_path_part_factory``.  The
    destination is always rendered with a trailing "/".

    Raises:
        ClickException: if the destination resolves to a user stage.

    Returns:
        The result of ``_execute_query`` for the generated query.
    """
    src_parts = self._stage_path_part_factory(source_path)
    dst_parts = self._stage_path_part_factory(destination_path)

    # Only a named stage is accepted as the copy target.
    if isinstance(dst_parts, UserStagePathParts):
        raise ClickException(
            "Destination path cannot be a user stage. Please provide a named stage."
        )

    source = src_parts.get_standard_stage_path()
    destination = dst_parts.get_standard_stage_directory_path()
    log.info("Copying files from %s to %s", source, destination)
    return self._execute_query(f"copy files into {destination} from {source}")
@@ -217,8 +316,7 @@ class StageManager(SqlExecutionMixin):
217
316
  on_error: OnErrorType,
218
317
  variables: Optional[List[str]] = None,
219
318
  ):
220
- stage_path_with_prefix = self.get_standard_stage_prefix(stage_path)
221
- stage_path_parts = self._split_stage_path(stage_path_with_prefix)
319
+ stage_path_parts = self._stage_path_part_factory(stage_path)
222
320
  all_files_list = self._get_files_list_from_stage(stage_path_parts)
223
321
 
224
322
  # filter files from stage if match stage_path pattern
@@ -228,44 +326,44 @@ class StageManager(SqlExecutionMixin):
228
326
  raise ClickException(f"No files matched pattern '{stage_path}'")
229
327
 
230
328
  # sort filtered files in alphabetical order with directories at the end
231
- sorted_file_list = sorted(
329
+ sorted_file_path_list = sorted(
232
330
  filtered_file_list, key=lambda f: (path.dirname(f), path.basename(f))
233
331
  )
234
332
 
235
- sql_variables = self._parse_execute_variables(variables)
333
+ parsed_variables = parse_key_value_variables(variables)
334
+ sql_variables = self._parse_execute_variables(parsed_variables)
335
+ python_variables = {str(v.key): v.value for v in parsed_variables}
236
336
  results = []
237
- for file in sorted_file_list:
238
- results.append(
239
- self._call_execute_immediate(
240
- stage_path_parts=stage_path_parts,
241
- file=file,
337
+
338
+ if any(file.endswith(".py") for file in sorted_file_path_list):
339
+ self._python_exe_procedure = self._bootstrap_snowpark_execution_environment(
340
+ stage_path_parts
341
+ )
342
+
343
+ for file_path in sorted_file_path_list:
344
+ file_stage_path = stage_path_parts.add_stage_prefix(file_path)
345
+ if file_path.endswith(".py"):
346
+ result = self._execute_python(
347
+ file_stage_path=file_stage_path,
348
+ on_error=on_error,
349
+ variables=python_variables,
350
+ )
351
+ else:
352
+ result = self._call_execute_immediate(
353
+ file_stage_path=file_stage_path,
242
354
  variables=sql_variables,
243
355
  on_error=on_error,
244
356
  )
245
- )
357
+ results.append(result)
246
358
 
247
359
  return results
248
360
 
249
- def _split_stage_path(self, stage_path: str) -> StagePathParts:
250
- """
251
- Splits stage path `@stage/dir` to
252
- stage -> @stage
253
- stage_name -> stage
254
- directory -> dir
255
- For stage path with fully qualified name `@db.schema.stage/dir`
256
- stage -> @db.schema.stage
257
- stage_name -> stage
258
- directory -> dir
259
- """
260
- stage = self.get_stage_from_path(stage_path)
261
- stage_name = stage.split(".")[-1]
262
- if stage_name.startswith("@"):
263
- stage_name = stage_name[1:]
264
- directory = "/".join(Path(stage_path).parts[1:])
265
- return StagePathParts(stage, stage_name, directory)
266
-
267
- def _get_files_list_from_stage(self, stage_path_parts: StagePathParts) -> List[str]:
268
- files_list_result = self.list_files(stage_path_parts.stage).fetchall()
361
+ def _get_files_list_from_stage(
362
+ self, stage_path_parts: StagePathParts, pattern: str | None = None
363
+ ) -> List[str]:
364
+ files_list_result = self.list_files(
365
+ stage_path_parts.stage, pattern=pattern
366
+ ).fetchall()
269
367
 
270
368
  if not files_list_result:
271
369
  raise ClickException(f"No files found on stage '{stage_path_parts.stage}'")
@@ -278,7 +376,7 @@ class StageManager(SqlExecutionMixin):
278
376
  if not stage_path_parts.directory:
279
377
  return self._filter_supported_files(files_on_stage)
280
378
 
281
- stage_path = stage_path_parts.path
379
+ stage_path = stage_path_parts.path.lower()
282
380
 
283
381
  # Exact file path was provided if stage_path in file list
284
382
  if stage_path in files_on_stage:
@@ -287,9 +385,8 @@ class StageManager(SqlExecutionMixin):
287
385
  return filtered_files
288
386
  else:
289
387
  raise ClickException(
290
- "Invalid file extension, only `.sql` files are allowed."
388
+ f"Invalid file extension, only {', '.join(EXECUTE_SUPPORTED_FILES_FORMATS)} files are allowed."
291
389
  )
292
-
293
390
  # Filter with fnmatch if contains `*` or `?`
294
391
  if glob.has_magic(stage_path):
295
392
  filtered_files = fnmatch.filter(files_on_stage, stage_path)
@@ -303,38 +400,145 @@ class StageManager(SqlExecutionMixin):
303
400
  return [f for f in files if Path(f).suffix in EXECUTE_SUPPORTED_FILES_FORMATS]
304
401
 
305
402
  @staticmethod
306
- def _parse_execute_variables(variables: Optional[List[str]]) -> Optional[str]:
403
+ def _parse_execute_variables(variables: List[Variable]) -> Optional[str]:
307
404
  if not variables:
308
405
  return None
309
-
310
- parsed_variables = parse_key_value_variables(variables)
311
- query_parameters = [f"{v.key}=>{v.value}" for v in parsed_variables]
406
+ query_parameters = [f"{v.key}=>{v.value}" for v in variables]
312
407
  return f" using ({', '.join(query_parameters)})"
313
408
 
409
@staticmethod
def _success_result(file: str):
    """Print a ``SUCCESS - <file>`` console message and return the result row for ``file``."""
    cli_console.warning(f"SUCCESS - {file}")
    return dict(File=file, Status="SUCCESS", Error=None)
413
+
414
@staticmethod
def _error_result(file: str, msg: str):
    """Print a ``FAILURE - <file>`` console message and return the result row carrying ``msg``."""
    cli_console.warning(f"FAILURE - {file}")
    return dict(File=file, Status="FAILURE", Error=msg)
418
+
419
@staticmethod
def _handle_execution_exception(on_error: OnErrorType, exception: Exception):
    """
    Re-raise ``exception`` when ``on_error`` is ``OnErrorType.BREAK``.
    For any other value nothing happens and None is returned.
    """
    should_abort = on_error == OnErrorType.BREAK
    if should_abort:
        raise exception
423
+
314
424
def _call_execute_immediate(
    self,
    file_stage_path: str,
    variables: Optional[str],
    on_error: OnErrorType,
) -> Dict:
    """
    Run ``execute immediate from <file_stage_path>`` and report the outcome.

    ``variables`` is an already-rendered query suffix.  It is appended
    verbatim when it is non-empty.

    Returns the success row, or the failure row when a ProgrammingError was
    raised and ``_handle_execution_exception`` did not re-raise it.
    """
    # Building the statement cannot raise ProgrammingError, so it stays
    # outside the guarded section.
    statement = f"execute immediate from {file_stage_path}" + (variables or "")
    try:
        self._execute_query(statement)
        return StageManager._success_result(file=file_stage_path)
    except ProgrammingError as err:
        StageManager._handle_execution_exception(on_error=on_error, exception=err)
        return StageManager._error_result(file=file_stage_path, msg=err.msg)
439
+
440
@staticmethod
def _stage_path_part_factory(stage_path: str) -> StagePathParts:
    """
    Build the StagePathParts flavour that matches ``stage_path``.

    The path is first normalised with ``get_standard_stage_prefix``.  Paths
    starting with ``USER_STAGE_PREFIX`` become UserStagePathParts and
    everything else becomes DefaultStagePathParts.
    """
    normalized = StageManager.get_standard_stage_prefix(stage_path)
    parts_cls = (
        UserStagePathParts
        if normalized.startswith(USER_STAGE_PREFIX)
        else DefaultStagePathParts
    )
    return parts_cls(normalized)
446
+
447
def _check_for_requirements_file(
    self, stage_path_parts: StagePathParts
) -> List[str]:
    """
    Look for a requirements.txt file on the stage and return its package names.

    Candidate locations run from ``stage_path_parts.path`` upwards, one
    segment at a time.  The first candidate found in the stage listing is
    downloaded to a temporary directory and parsed.  An empty list is
    returned when no candidate is found.
    """
    listed_files = self._get_files_list_from_stage(
        stage_path_parts, pattern=r".*requirements\.txt$"
    )
    # NOTE(review): `_get_files_list_from_stage` raises ClickException on an
    # empty listing elsewhere in this file, so this guard may never trigger
    # - confirm.
    if not listed_files:
        return []

    # Plain string handling is used instead of os.path/pathlib to preserve
    # compatibility on Windows.
    req_file_name = "requirements.txt"
    segments = stage_path_parts.path.split("/")
    candidates = [
        "/".join([*segments[:depth], req_file_name])
        for depth in range(len(segments), 0, -1)
    ]

    # The deepest candidate present on the stage wins.
    requirements_file = next(
        (candidate for candidate in candidates if candidate in listed_files), None
    )
    if requirements_file is None:
        return []

    with SecurePath.temporary_directory() as tmp_dir:
        self.get(
            stage_path_parts.get_full_stage_path(requirements_file), tmp_dir.path
        )
        requirements = parse_requirements(
            requirements_file=tmp_dir / "requirements.txt"
        )

    return [req.package_name for req in requirements]
490
+
491
def _bootstrap_snowpark_execution_environment(
    self, stage_path_parts: StagePathParts
):
    """
    Prepares Snowpark session for executing Python code remotely.

    Adds the required packages, plus those from a requirements.txt found on
    the stage, to the Snowpark session.  Then registers and returns a
    non-permanent stored procedure that executes one staged Python file.

    Raises:
        ClickException: when running on Python >= 3.12.
    """
    if sys.version_info >= PYTHON_3_12:
        raise ClickException(
            f"Executing python files is not supported in Python >= 3.12. Current version: {sys.version}"
        )

    from snowflake.snowpark.functions import sproc

    self.snowpark_session.add_packages("snowflake-snowpark-python")
    self.snowpark_session.add_packages("snowflake.core")
    requirements = self._check_for_requirements_file(stage_path_parts)
    self.snowpark_session.add_packages(*requirements)

    @sproc(is_permanent=False)
    def _python_execution_procedure(
        _: Session, file_path: str, variables: Dict | None = None
    ) -> None:
        """Snowpark session-scoped stored procedure to execute content of provided python file."""
        import json

        from snowflake.snowpark.files import SnowflakeFile

        with SnowflakeFile.open(file_path, require_scoped_url=False) as f:
            file_content: str = f.read()  # type: ignore

        # With the default `variables=None`, json.dumps would emit `null`,
        # which is not a Python name, and the wrapper would fail with
        # NameError.  Fall back to an empty mapping instead.
        env_updates = json.dumps(variables or {})
        wrapper = dedent(
            f"""\
            import os
            os.environ.update({env_updates})
            """
        )

        # NOTE(security): exec runs the staged file's content as-is.  Only
        # files the caller explicitly chose to execute should reach here.
        exec(wrapper + file_content)

    return _python_execution_procedure
529
+
530
def _execute_python(
    self, file_stage_path: str, on_error: OnErrorType, variables: Dict
):
    """
    Executes Python file from stage using a Snowpark temporary procedure.

    ``variables`` is passed to the procedure as its second argument.
    Returns the success row, or the failure row when a SnowparkSQLException
    was raised and ``_handle_execution_exception`` did not re-raise it.
    """
    from snowflake.snowpark.exceptions import SnowparkSQLException

    try:
        stage_location = self.get_standard_stage_prefix(file_stage_path)
        self._python_exe_procedure(stage_location, variables)  # type: ignore
        return StageManager._success_result(file=file_stage_path)
    except SnowparkSQLException as err:
        StageManager._handle_execution_exception(on_error=on_error, exception=err)
        return StageManager._error_result(file=file_stage_path, msg=err.message)
@@ -0,0 +1,160 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import hashlib
18
+ import logging
19
+ import math
20
+ import os.path
21
+ import re
22
+ from pathlib import Path
23
+ from typing import List, Tuple
24
+
25
+ from click.exceptions import ClickException
26
+ from snowflake.cli.api.secure_path import UNLIMITED, SecurePath
27
+ from snowflake.connector.constants import S3_CHUNK_SIZE, S3_MAX_PARTS, S3_MIN_PART_SIZE
28
+
29
# Number of bytes in one megabyte (1024 * 1024).
ONE_MEGABYTE = 1024 * 1024
# Largest single read issued while hashing a file.
READ_BUFFER_BYTES = 64 * 1024
# A plain md5 digest: exactly 32 hexadecimal characters.
MD5SUM_REGEX = r"^[A-Fa-f0-9]{32}$"
# A multi-part digest: 32 hexadecimal characters, "-", then a decimal part count.
MULTIPART_MD5SUM_REGEX = r"^([A-Fa-f0-9]{32})-(\d+)$"

log = logging.getLogger(__name__)
35
+
36
+
37
class UnknownMD5FormatError(ClickException):
    """Raised for a checksum string that cannot be handled; the message includes the string."""

    def __init__(self, md5: str):
        message = f"Unknown md5 format: {md5}"
        super().__init__(message)
40
+
41
+
42
def is_md5sum(checksum: str) -> bool:
    """
    Could the provided hexadecimal checksum represent a valid md5sum?

    The whole string has to match MD5SUM_REGEX.  ``re.fullmatch`` is used
    because ``$`` under ``re.match`` also matches just before a trailing
    newline, so "<32 hex chars>\\n" used to be accepted.
    """
    return re.fullmatch(MD5SUM_REGEX, checksum) is not None
47
+
48
+
49
def parse_multipart_md5sum(checksum: str) -> Tuple[str, int] | None:
    """
    Does this represent a multi-part md5sum (i.e. "<md5>-<n>")?
    If so, returns the tuple (md5, n), otherwise None.

    The whole string has to match MULTIPART_MD5SUM_REGEX.  ``re.fullmatch``
    is used so that a trailing newline is no longer tolerated.
    """
    multipart_md5 = re.fullmatch(MULTIPART_MD5SUM_REGEX, checksum)
    if multipart_md5 is None:
        return None
    digest, part_count = multipart_md5.groups()
    return digest, int(part_count)
58
+
59
+
60
def compute_md5sum(file: Path, chunk_size: int | None = None) -> str:
    """
    Returns a hexadecimal checksum for the file located at the given path.
    If chunk_size is given, computes a multi-part md5sum.

    A multi-part md5sum has the form "<md5 of the concatenated per-chunk
    digests>-<number of chunks>".  A chunk_size of None or 0 gives the plain
    md5 of the whole file.  An empty file always gives the plain md5 of no
    content.

    Raises:
        ValueError: if ``file`` is not a regular file, or if ``chunk_size``
            is negative for a non-empty file.
    """
    if not file.is_file():
        raise ValueError(
            "The provided file does not exist or not a (symlink to a) regular file"
        )

    # If the stage uses SNOWFLAKE_FULL encryption, this will fail to provide
    # a matching md5sum, even when the underlying file is the same, as we do
    # not have access to the encrypted file under checksum.

    file_size = os.path.getsize(file)
    if file_size == 0:
        # simple md5 with no content
        return hashlib.md5().hexdigest()

    if chunk_size is not None and chunk_size < 0:
        # A negative size would turn into read(-1) and a byte counter that
        # grows instead of shrinking, so the loop below would never finish.
        raise ValueError(f"chunk_size must not be negative, got {chunk_size}")

    with SecurePath(file).open("rb", read_file_limit_mb=UNLIMITED) as f:
        md5s: List[hashlib._Hash] = []  # noqa: SLF001
        hasher = hashlib.md5()

        remains = file_size
        remains_in_chunk: int = min(chunk_size, remains) if chunk_size else remains
        while remains > 0:
            # Never read past the end of the current chunk.
            sz = min(READ_BUFFER_BYTES, remains_in_chunk)
            buf = f.read(sz)
            hasher.update(buf)
            remains_in_chunk -= sz
            remains -= sz
            if remains_in_chunk == 0:
                if not chunk_size:
                    # simple md5; only one chunk processed
                    return hasher.hexdigest()
                else:
                    # push the hash of this chunk + reset
                    md5s.append(hasher)
                    hasher = hashlib.md5()
                    remains_in_chunk = min(chunk_size, remains)

        # multi-part hash (e.g. aws)
        digests = b"".join(m.digest() for m in md5s)
        digests_md5 = hashlib.md5(digests)
        return f"{digests_md5.hexdigest()}-{len(md5s)}"
105
+
106
+
107
def file_matches_md5sum(local_file: Path, remote_md5: str | None) -> bool:
    """
    Try a few different md5sums to determine if a local file is identical
    to a file that has a given remote md5sum.

    Handles the multi-part md5sums generated by e.g. AWS S3, using values
    from the python connector to make educated guesses on chunk size.

    Assumes that upload time would dominate local hashing time.

    The comparison ignores the letter case of the remote checksum.  The
    accepted formats allow upper-case hexadecimal digits, while locally
    computed digests are always lower-case.

    Returns:
        True if one of the candidate checksums equals ``remote_md5``.
        False otherwise, including when no remote checksum is available or
        a multi-part checksum claims fewer than one part.

    Raises:
        UnknownMD5FormatError: if ``remote_md5`` is neither a plain nor a
            multi-part md5sum.
    """
    if not remote_md5:
        # no hash available
        return False

    # hashlib produces lower-case hex; normalise the remote side to match.
    expected = remote_md5.lower()

    if is_md5sum(remote_md5):
        # regular hash
        return compute_md5sum(local_file) == expected

    if md5_and_chunks := parse_multipart_md5sum(remote_md5):
        # multi-part hash (e.g. aws)
        (_, num_chunks) = md5_and_chunks
        if num_chunks < 1:
            # "<md5>-0" cannot describe any upload and would otherwise
            # divide by zero in the chunk-size estimate below.
            log.debug("multi-part md5: %s != %s", remote_md5, local_file)
            return False

        file_size = os.path.getsize(local_file)

        # If this file uses the maximum number of parts supported by the cloud backend,
        # the chunk size is likely not a clean multiple of a megabyte. Try reverse engineering
        # from the file size first, then fall back to the usual detection method.
        # At time of writing this logic would trigger for files >= 80GiB (python connector)
        if num_chunks == S3_MAX_PARTS:
            chunk_size = max(math.ceil(file_size / S3_MAX_PARTS), S3_MIN_PART_SIZE)
            if compute_md5sum(local_file, chunk_size) == expected:
                return True

        # Estimates the chunk size the multi-part file must have been uploaded with
        # by trying chunk sizes that give the most evenly-sized chunks.
        #
        # First we'll try the chunk size that's a multiple of S3_CHUNK_SIZE (8mb) from
        # the python connector that results in num_chunks, then we'll do the same with
        # a smaller granularity (1mb) that is used by default in some AWS multi-part
        # upload implementations.
        #
        # We're working backwards from num_chunks here because it's the only value we know.
        for chunk_size_alignment in [S3_CHUNK_SIZE, ONE_MEGABYTE]:
            # +1 because we need at least one chunk when file_size < num_chunks * chunk_size_alignment
            # -1 because we don't want to add an extra chunk when file_size is an exact multiple of num_chunks * chunk_size_alignment
            multiplier = 1 + ((file_size - 1) // (num_chunks * chunk_size_alignment))
            chunk_size = multiplier * chunk_size_alignment
            if compute_md5sum(local_file, chunk_size) == expected:
                return True

        # we were unable to figure out the chunk size, or the files are different
        log.debug("multi-part md5: %s != %s", remote_md5, local_file)
        return False

    raise UnknownMD5FormatError(remote_md5)