snowflake-cli-labs 2.3.0rc1__py3-none-any.whl → 2.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. snowflake/cli/__about__.py +1 -1
  2. snowflake/cli/api/__init__.py +2 -0
  3. snowflake/cli/api/cli_global_context.py +8 -1
  4. snowflake/cli/api/commands/decorators.py +2 -2
  5. snowflake/cli/api/commands/flags.py +49 -4
  6. snowflake/cli/api/commands/snow_typer.py +2 -0
  7. snowflake/cli/api/console/abc.py +2 -0
  8. snowflake/cli/api/console/console.py +6 -5
  9. snowflake/cli/api/constants.py +5 -0
  10. snowflake/cli/api/exceptions.py +12 -0
  11. snowflake/cli/api/identifiers.py +123 -0
  12. snowflake/cli/api/plugins/command/__init__.py +2 -0
  13. snowflake/cli/api/plugins/plugin_config.py +2 -0
  14. snowflake/cli/api/project/definition.py +2 -0
  15. snowflake/cli/api/project/errors.py +3 -3
  16. snowflake/cli/api/project/schemas/identifier_model.py +35 -0
  17. snowflake/cli/api/project/schemas/native_app/native_app.py +4 -0
  18. snowflake/cli/api/project/schemas/native_app/path_mapping.py +21 -3
  19. snowflake/cli/api/project/schemas/project_definition.py +58 -6
  20. snowflake/cli/api/project/schemas/snowpark/argument.py +2 -0
  21. snowflake/cli/api/project/schemas/snowpark/callable.py +8 -17
  22. snowflake/cli/api/project/schemas/streamlit/streamlit.py +2 -2
  23. snowflake/cli/api/project/schemas/updatable_model.py +2 -0
  24. snowflake/cli/api/project/util.py +2 -0
  25. snowflake/cli/api/secure_path.py +2 -0
  26. snowflake/cli/api/sql_execution.py +14 -54
  27. snowflake/cli/api/utils/cursor.py +2 -0
  28. snowflake/cli/api/utils/models.py +23 -0
  29. snowflake/cli/api/utils/naming_utils.py +0 -27
  30. snowflake/cli/api/utils/rendering.py +178 -23
  31. snowflake/cli/app/api_impl/plugin/plugin_config_provider_impl.py +2 -0
  32. snowflake/cli/app/cli_app.py +4 -1
  33. snowflake/cli/app/commands_registration/builtin_plugins.py +8 -0
  34. snowflake/cli/app/commands_registration/command_plugins_loader.py +2 -0
  35. snowflake/cli/app/commands_registration/commands_registration_with_callbacks.py +2 -0
  36. snowflake/cli/app/commands_registration/typer_registration.py +2 -0
  37. snowflake/cli/app/dev/pycharm_remote_debug.py +2 -0
  38. snowflake/cli/app/loggers.py +2 -0
  39. snowflake/cli/app/main_typer.py +1 -1
  40. snowflake/cli/app/printing.py +3 -1
  41. snowflake/cli/app/snow_connector.py +2 -2
  42. snowflake/cli/plugins/connection/commands.py +5 -14
  43. snowflake/cli/plugins/connection/util.py +1 -1
  44. snowflake/cli/plugins/cortex/__init__.py +0 -0
  45. snowflake/cli/plugins/cortex/commands.py +312 -0
  46. snowflake/cli/plugins/cortex/constants.py +3 -0
  47. snowflake/cli/plugins/cortex/manager.py +175 -0
  48. snowflake/cli/plugins/cortex/plugin_spec.py +16 -0
  49. snowflake/cli/plugins/cortex/types.py +8 -0
  50. snowflake/cli/plugins/git/commands.py +15 -0
  51. snowflake/cli/plugins/nativeapp/artifacts.py +368 -123
  52. snowflake/cli/plugins/nativeapp/codegen/artifact_processor.py +45 -0
  53. snowflake/cli/plugins/nativeapp/codegen/compiler.py +104 -0
  54. snowflake/cli/plugins/nativeapp/codegen/sandbox.py +2 -0
  55. snowflake/cli/plugins/nativeapp/codegen/snowpark/callback_source.py.jinja +181 -0
  56. snowflake/cli/plugins/nativeapp/codegen/snowpark/extension_function_utils.py +196 -0
  57. snowflake/cli/plugins/nativeapp/codegen/snowpark/models.py +47 -0
  58. snowflake/cli/plugins/nativeapp/codegen/snowpark/python_processor.py +489 -0
  59. snowflake/cli/plugins/nativeapp/commands.py +11 -4
  60. snowflake/cli/plugins/nativeapp/common_flags.py +12 -5
  61. snowflake/cli/plugins/nativeapp/constants.py +1 -0
  62. snowflake/cli/plugins/nativeapp/manager.py +49 -16
  63. snowflake/cli/plugins/nativeapp/policy.py +2 -0
  64. snowflake/cli/plugins/nativeapp/run_processor.py +28 -10
  65. snowflake/cli/plugins/nativeapp/teardown_processor.py +80 -8
  66. snowflake/cli/plugins/nativeapp/utils.py +7 -6
  67. snowflake/cli/plugins/nativeapp/version/commands.py +6 -5
  68. snowflake/cli/plugins/nativeapp/version/version_processor.py +2 -0
  69. snowflake/cli/plugins/notebook/commands.py +21 -0
  70. snowflake/cli/plugins/notebook/exceptions.py +6 -0
  71. snowflake/cli/plugins/notebook/manager.py +46 -3
  72. snowflake/cli/plugins/notebook/types.py +2 -0
  73. snowflake/cli/plugins/object/command_aliases.py +80 -0
  74. snowflake/cli/plugins/object/commands.py +10 -6
  75. snowflake/cli/plugins/object/common.py +2 -0
  76. snowflake/cli/plugins/object_stage_deprecated/__init__.py +1 -0
  77. snowflake/cli/plugins/object_stage_deprecated/plugin_spec.py +20 -0
  78. snowflake/cli/plugins/snowpark/commands.py +62 -6
  79. snowflake/cli/plugins/snowpark/common.py +17 -6
  80. snowflake/cli/plugins/spcs/compute_pool/commands.py +22 -1
  81. snowflake/cli/plugins/spcs/compute_pool/manager.py +2 -0
  82. snowflake/cli/plugins/spcs/image_repository/commands.py +25 -1
  83. snowflake/cli/plugins/spcs/image_repository/manager.py +3 -1
  84. snowflake/cli/plugins/spcs/services/commands.py +39 -5
  85. snowflake/cli/plugins/spcs/services/manager.py +2 -0
  86. snowflake/cli/plugins/sql/commands.py +13 -5
  87. snowflake/cli/plugins/sql/manager.py +40 -19
  88. snowflake/cli/plugins/stage/commands.py +29 -3
  89. snowflake/cli/plugins/stage/diff.py +2 -0
  90. snowflake/cli/plugins/streamlit/commands.py +26 -10
  91. snowflake/cli/plugins/streamlit/manager.py +9 -10
  92. {snowflake_cli_labs-2.3.0rc1.dist-info → snowflake_cli_labs-2.4.0.dist-info}/METADATA +4 -2
  93. {snowflake_cli_labs-2.3.0rc1.dist-info → snowflake_cli_labs-2.4.0.dist-info}/RECORD +97 -77
  94. /snowflake/cli/plugins/{object/stage_deprecated → object_stage_deprecated}/commands.py +0 -0
  95. {snowflake_cli_labs-2.3.0rc1.dist-info → snowflake_cli_labs-2.4.0.dist-info}/WHEEL +0 -0
  96. {snowflake_cli_labs-2.3.0rc1.dist-info → snowflake_cli_labs-2.4.0.dist-info}/entry_points.txt +0 -0
  97. {snowflake_cli_labs-2.3.0rc1.dist-info → snowflake_cli_labs-2.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,104 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Dict, Optional
5
+
6
+ from snowflake.cli.api.console import cli_console as cc
7
+ from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp
8
+ from snowflake.cli.api.project.schemas.native_app.path_mapping import (
9
+ PathMapping,
10
+ ProcessorMapping,
11
+ )
12
+ from snowflake.cli.plugins.nativeapp.artifacts import resolve_without_follow
13
+ from snowflake.cli.plugins.nativeapp.codegen.artifact_processor import (
14
+ ArtifactProcessor,
15
+ UnsupportedArtifactProcessorError,
16
+ )
17
+ from snowflake.cli.plugins.nativeapp.codegen.snowpark.python_processor import (
18
+ SnowparkAnnotationProcessor,
19
+ )
20
+
21
# Processor name (matched case-insensitively against a processor mapping's name) that
# selects the Snowpark annotation processor; also its key in NativeAppCompiler's cache.
SNOWPARK_PROCESSOR = "snowpark"
22
+
23
+
24
class NativeAppCompiler:
    """
    Compiler class to perform custom processing on all relevant Native Apps artifacts
    (specified in the project definition file) before an application package can be
    created from those artifacts.

    An artifact can have more than one processor specified for itself; this class runs
    those processors in the order in which they are listed. Processors are created on
    first use and kept in ``cached_processors`` so that one instance is reused across
    artifacts, since processor initialization is independent of the artifact to process.
    """

    def __init__(
        self,
        project_definition: NativeApp,
        project_root: Path,
        deploy_root: Path,
        generated_root: Path,
    ):
        self.project_definition = project_definition
        self.project_root = project_root
        self.deploy_root = deploy_root
        self.generated_root = generated_root

        # Only PathMapping artifacts are considered; other artifact forms are skipped.
        self.artifacts = [
            artifact
            for artifact in project_definition.artifacts
            if isinstance(artifact, PathMapping)
        ]
        # Processors created so far, keyed by processor name, shared between artifacts.
        self.cached_processors: Dict[str, ArtifactProcessor] = {}

    def compile_artifacts(self):
        """
        Runs, in order of specification, every processor declared on every artifact.

        May have side effects on the filesystem, either by directly editing source files
        or by writing to the deploy root.

        Raises:
            UnsupportedArtifactProcessorError: if an artifact names a processor that this
                compiler cannot create.
        """
        # Skip the console phase entirely when no artifact declares any processor.
        if not any(artifact.processors for artifact in self.artifacts):
            return

        with cc.phase("Invoking artifact processors"):
            for artifact in self.artifacts:
                for processor in artifact.processors:
                    artifact_processor = self._try_create_processor(
                        processor_mapping=processor,
                    )
                    if artifact_processor is None:
                        raise UnsupportedArtifactProcessorError(
                            processor_name=processor.name
                        )
                    artifact_processor.process(
                        artifact_to_process=artifact, processor_mapping=processor
                    )

    def _try_create_processor(
        self,
        processor_mapping: ProcessorMapping,
        **kwargs,
    ) -> Optional[ArtifactProcessor]:
        """
        Returns the processor named by ``processor_mapping``.

        An already-created instance is taken from ``cached_processors``; otherwise a new
        one is created and cached. Only the Snowpark processor is currently known (the
        name comparison is case-insensitive).

        Args:
            processor_mapping: the processor entry declared on an artifact.
            **kwargs: currently unused.

        Returns:
            The processor instance, or None if the processor name is not recognized.
        """
        if processor_mapping.name.lower() != SNOWPARK_PROCESSOR:
            return None

        curr_processor = self.cached_processors.get(SNOWPARK_PROCESSOR)
        if curr_processor is None:
            curr_processor = SnowparkAnnotationProcessor(
                project_definition=self.project_definition,
                project_root=resolve_without_follow(self.project_root),
                deploy_root=resolve_without_follow(self.deploy_root),
                generated_root=resolve_without_follow(self.generated_root),
            )
            self.cached_processors[SNOWPARK_PROCESSOR] = curr_processor
        return curr_processor
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  import shutil
3
5
  import subprocess
@@ -0,0 +1,181 @@
1
+ import contextlib
2
+ import functools
3
+ import inspect
4
+ import sys
5
+ from typing import Callable
6
+
7
# Everything below depends on the Snowpark package, so bail out early (with a message on
# stderr) when it cannot be imported.
try:
    import snowflake.snowpark
except ModuleNotFoundError as exc:
    print(
        "An exception occurred while importing snowflake-snowpark-python package: ",
        exc,
        file=sys.stderr,
    )
    sys.exit(1)

# The two private hooks on snowflake.snowpark.context that this script overrides must
# both exist; otherwise the installed Snowpark is too old.
__snowflake_internal_found_correct_version = all(
    hasattr(snowflake.snowpark.context, hook_name)
    for hook_name in (
        "_is_execution_environment_sandboxed_for_client",
        "_should_continue_registration",
    )
)

if not __snowflake_internal_found_correct_version:
    print(
        "Did not find the minimum required version for snowflake-snowpark-python package. Please upgrade to v1.15.0 or higher.",
        file=sys.stderr,
    )
    sys.exit(1)

# One entry is appended here per extension function registration observed while the
# target file executes; the list is printed as JSON at the very end of the script.
__snowflake_global_collected_extension_fn_json = []
30
def __snowflake_internal_create_extension_fn_registration_callback():
    """
    Builds the callback that gets installed as
    snowflake.snowpark.context._should_continue_registration.

    The returned callable turns the properties of each extension function registration
    into a dict, appends it to __snowflake_global_collected_extension_fn_json and returns
    False (presumably telling Snowpark not to proceed with the actual registration).
    """

    def __snowflake_internal_try_extract_lineno(extension_fn):
        # Best effort: None whenever the source lines cannot be retrieved.
        try:
            return inspect.getsourcelines(extension_fn)[1]
        except Exception:
            return None

    def __snowflake_internal_extract_extension_fn_name(extension_fn):
        try:
            import snowflake.snowpark._internal.utils as snowpark_utils

            if hasattr(snowpark_utils, 'TEMP_OBJECT_NAME_PREFIX'):
                if extension_fn.object_name.startswith(snowpark_utils.TEMP_OBJECT_NAME_PREFIX):
                    # The object name is a generated one, don't use it
                    return None

        except Exception:
            # ignore any exception and fall back to using the object name reported from Snowpark
            pass

        return extension_fn.object_name

    def __snowflake_internal_create_package_list(extension_fn):
        # all_packages is a comma-separated string; an empty/blank string means no packages.
        if not extension_fn.all_packages.strip():
            return []
        return [pkg_spec.strip() for pkg_spec in extension_fn.all_packages.split(",")]

    def __snowflake_internal_make_extension_fn_signature(extension_fn):
        # Try to fetch the original argument names (and default values) through reflection
        try:
            args_spec = inspect.getfullargspec(extension_fn.func)
            original_arg_names = args_spec.args
            defaults = args_spec.defaults or ()
            input_sql_types = extension_fn.input_sql_types

            # The SQL-typed arguments are the trailing Python parameters; any leading
            # parameters without a SQL type are skipped.
            start_index = len(original_arg_names) - len(input_sql_types)
            if start_index < 0:
                # More SQL types than named parameters: reflection cannot line them up.
                raise ValueError("more SQL argument types than Python parameters")

            # Python attaches defaults to the trailing positional parameters, so the first
            # parameter with a default sits at this position in original_arg_names.
            defaults_start_index = len(original_arg_names) - len(defaults)

            signature = []
            for offset, sql_type in enumerate(input_sql_types):
                arg_index = start_index + offset
                arg = {
                    "name": original_arg_names[arg_index],
                    "type": sql_type,
                }
                if arg_index >= defaults_start_index:
                    arg["default"] = defaults[arg_index - defaults_start_index]
                signature.append(arg)

            return signature
        except Exception:
            pass  # ignore, we'll use the fallback strategy

        # Failed to extract the original arguments through reflection, fall back to alternative approach
        return [
            {"name": input_arg.name, "type": input_type}
            for (input_arg, input_type) in zip(extension_fn.input_args, extension_fn.input_sql_types)
        ]

    def __snowflake_internal_to_extension_fn_type(object_type):
        # Map the Snowpark object type name to the lowercase, space-separated form.
        if object_type.name == "AGGREGATE_FUNCTION":
            return "aggregate function"
        if object_type.name == "TABLE_FUNCTION":
            return "table function"
        return object_type.name.lower()

    def __snowflake_internal_imports_union_to_str_type(raw_imports):
        # Each raw import is either a plain string or a sequence whose first item is used.
        final_imports = []
        if raw_imports:
            for raw_import in raw_imports:
                if isinstance(raw_import, str):
                    final_imports.append(raw_import)
                else:
                    final_imports.append(raw_import[0])
        return final_imports

    def __snowflake_internal_extension_fn_to_json(extension_fn):
        if not isinstance(extension_fn.func, Callable):
            # Unsupported case: extension function is a tuple
            return

        if extension_fn.anonymous:
            # unsupported, native application extension functions need to be explicitly named
            return

        # Collect basic properties of the extension function
        extension_fn_json = {
            "type": __snowflake_internal_to_extension_fn_type(extension_fn.object_type),
            "lineno": __snowflake_internal_try_extract_lineno(extension_fn.func),
            "name": __snowflake_internal_extract_extension_fn_name(extension_fn),
            "handler": extension_fn.func.__name__,
            "imports": __snowflake_internal_imports_union_to_str_type(extension_fn.raw_imports),
            "packages": __snowflake_internal_create_package_list(extension_fn),
            "runtime": extension_fn.runtime_version,
            "returns": extension_fn.return_sql.upper().replace("RETURNS ", "").strip(),
            "signature": __snowflake_internal_make_extension_fn_signature(extension_fn),
            "external_access_integrations": extension_fn.external_access_integrations or [],
            "secrets": extension_fn.secrets or {},
        }

        if extension_fn.object_type.name == "PROCEDURE" and extension_fn.execute_as is not None:
            extension_fn_json['execute_as_caller'] = (extension_fn.execute_as == 'caller')

        if extension_fn.native_app_params is not None:
            schema = extension_fn.native_app_params.get("schema")
            if schema is not None:
                extension_fn_json["schema"] = schema
            app_roles = extension_fn.native_app_params.get("application_roles")
            if app_roles is not None:
                extension_fn_json["application_roles"] = app_roles

        return extension_fn_json

    def __snowflake_internal_collect_extension_fn(
        collected_extension_fn_json_list, extension_function_properties
    ):
        # NOTE(review): unsupported functions yield None, which is appended as-is.
        extension_fn_json = __snowflake_internal_extension_fn_to_json(extension_function_properties)
        collected_extension_fn_json_list.append(extension_fn_json)
        return False

    # Bind the module-level accumulator list now so every callback invocation appends to it.
    return functools.partial(
        __snowflake_internal_collect_extension_fn,
        __snowflake_global_collected_extension_fn_json,
    )
151
+
152
# Override Snowpark's private hooks: report the environment as a sandboxed client (in both
# the context and session modules) and route every extension function registration
# through the collecting callback.
snowflake.snowpark.context._is_execution_environment_sandboxed_for_client = (  # noqa: SLF001
    True
)
snowflake.snowpark.context._should_continue_registration = (  # noqa: SLF001
    __snowflake_internal_create_extension_fn_registration_callback()
)
snowflake.snowpark.session._is_execution_environment_sandboxed_for_client = (  # noqa: SLF001
    True
)

# Drop this script's private module-level helpers from the global namespace before the
# target file is executed.
for global_key in list(globals().keys()):
    if global_key.startswith("__snowflake_internal"):
        del globals()[global_key]

del globals()["global_key"]  # make sure to clean up the loop variable as well

try:
    # importlib.util is a submodule: a bare "import importlib" does not guarantee it has
    # been loaded, so import it explicitly.
    import importlib.util

    # Silence anything the target file prints so that stdout only carries the JSON below.
    # NOTE(review): py_file is substituted verbatim into a plain string literal; a path with
    # backslashes (e.g. on Windows) or quotes would break it unless escaped upstream - confirm.
    with contextlib.redirect_stdout(None):
        with contextlib.redirect_stderr(None):
            __snowflake_internal_spec = importlib.util.spec_from_file_location("<string>", "{{py_file}}")
            __snowflake_internal_module = importlib.util.module_from_spec(__snowflake_internal_spec)
            __snowflake_internal_spec.loader.exec_module(__snowflake_internal_module)
except Exception as exc:  # Catch any error
    print("An exception occurred while executing file: ", exc, file=sys.stderr)
    sys.exit(1)


import json
print(json.dumps(__snowflake_global_collected_extension_fn_json))
@@ -0,0 +1,196 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ from typing import (
5
+ Any,
6
+ List,
7
+ Optional,
8
+ Sequence,
9
+ )
10
+
11
+ from click.exceptions import ClickException
12
+ from snowflake.cli.api.project.schemas.snowpark.argument import Argument
13
+ from snowflake.cli.api.project.util import (
14
+ is_valid_identifier,
15
+ is_valid_string_literal,
16
+ to_identifier,
17
+ to_string_literal,
18
+ )
19
+ from snowflake.cli.plugins.nativeapp.codegen.snowpark.models import (
20
+ ExtensionFunctionTypeEnum,
21
+ NativeAppExtensionFunction,
22
+ )
23
+
24
+
25
class MalformedExtensionFunctionError(ClickException):
    """Raised when an extension function is missing a required attribute."""

    def __init__(self, message: str):
        # Hand the message straight to ClickException.
        super().__init__(message)
30
+
31
+
32
def get_sql_object_type(extension_fn: NativeAppExtensionFunction) -> Optional[str]:
    """
    Returns the SQL object type keyword for an extension function.

    Returns:
        "PROCEDURE" for procedures, "FUNCTION" for both scalar and table functions,
        "AGGREGATE FUNCTION" for aggregate functions, and None for any other type.
    """
    function_type = extension_fn.function_type
    if function_type == ExtensionFunctionTypeEnum.PROCEDURE:
        return "PROCEDURE"
    if function_type in (
        ExtensionFunctionTypeEnum.FUNCTION,
        ExtensionFunctionTypeEnum.TABLE_FUNCTION,
    ):
        return "FUNCTION"
    # Compare against the enum class itself, like the branches above, rather than reaching
    # the member through the instance (which fails if function_type is a plain str).
    if function_type == ExtensionFunctionTypeEnum.AGGREGATE_FUNCTION:
        return "AGGREGATE FUNCTION"
    return None
44
+
45
+
46
def get_sql_argument_signature(arg: Argument) -> str:
    """
    Formats one argument for a SQL signature as "<name> <type>", with
    " DEFAULT <default>" appended when the argument has a default value.
    """
    signature = f"{arg.name} {arg.arg_type}"
    return signature if arg.default is None else f"{signature} DEFAULT {arg.default}"
51
+
52
+
53
def get_function_type_signature_for_grant(
    extension_fn: NativeAppExtensionFunction,
) -> str:
    """
    Joins the argument types of the function's signature with ", " (e.g. "int, varchar"),
    a form suitable for inclusion in a GRANT statement.
    """
    arg_types = (arg.arg_type for arg in extension_fn.signature)
    return ", ".join(arg_types)
60
+
61
+
62
def get_qualified_object_name(extension_fn: NativeAppExtensionFunction) -> str:
    """
    Returns the function's name as an identifier, qualified by its schema when one is set.

    A schema name that is already a valid identifier is used as-is. Any other schema name
    is split on "." and each part is converted to an identifier separately.
    """
    object_name = to_identifier(extension_fn.name)
    schema = extension_fn.schema_name
    if not schema:
        return object_name

    if not is_valid_identifier(schema):
        # NOTE(review): a naive split also breaks quoted parts that contain a "." - confirm
        # whether such schema names can reach this point.
        schema = ".".join(to_identifier(part) for part in schema.split("."))
    return f"{schema}.{object_name}"
77
+
78
+
79
def ensure_string_literal(value: str) -> str:
    """
    Returns ``value`` unchanged when it is already a valid string literal; otherwise
    returns its string-literal representation.
    """
    return value if is_valid_string_literal(value) else to_string_literal(value)
87
+
88
+
89
def ensure_all_string_literals(values: Sequence[str]) -> List[str]:
    """
    Converts every value to a valid string literal, leaving values that already are one
    untouched.

    Returns:
        A new list with one string literal per input value, in the same order.
    """
    return list(map(ensure_string_literal, values))
97
+
98
+
99
+ class _FunctionDefAccumulator(ast.NodeVisitor):
100
+ """
101
+ A NodeVisitor that collects AST nodes corresponding to function declarations, filtered by a list
102
+ of wanted functions. This is used to identify all Snowpark extension functions in a module's
103
+ source code.
104
+ """
105
+
106
+ def __init__(self, functions: Sequence[NativeAppExtensionFunction]):
107
+ self._wanted_functions_by_name = {
108
+ fn.handler.split(".")[-1]: fn for fn in functions
109
+ }
110
+ self.definitions: List[Any] = []
111
+
112
+ def visit_FunctionDef(self, node: ast.FunctionDef): # noqa: N802
113
+ if self._want(node):
114
+ self.definitions.append(node)
115
+ self.generic_visit(node)
116
+
117
+ def _want(self, node: Any) -> bool:
118
+ if not node.decorator_list:
119
+ # No decorators for this definition, ignore it
120
+ return False
121
+
122
+ return node.name in self._wanted_functions_by_name
123
+
124
+
125
+ def _get_decorator_id(node: ast.AST) -> Optional[str]:
126
+ """
127
+ Returns the fully qualified identifier for a decorator, e.g. "foo" or "foo.bar".
128
+ """
129
+ if isinstance(node, ast.Name):
130
+ return node.id
131
+ elif isinstance(node, ast.Attribute):
132
+ return f"{_get_decorator_id(node.value)}.{node.attr}"
133
+ elif isinstance(node, ast.Call):
134
+ return _get_decorator_id(node.func)
135
+ else:
136
+ return None
137
+
138
+
139
def _collect_ast_function_definitions(
    tree: ast.AST, extension_functions: Sequence[NativeAppExtensionFunction]
) -> Sequence[ast.FunctionDef]:
    """
    Walks ``tree`` and returns the decorated function definitions whose names match the
    handlers of ``extension_functions``.
    """
    visitor = _FunctionDefAccumulator(extension_functions)
    visitor.visit(tree)
    return visitor.definitions
145
+
146
+
147
def deannotate_module_source(
    module_source: str,
    extension_functions: Sequence[NativeAppExtensionFunction],
    annotations_to_preserve: Sequence[str] = (),
) -> str:
    """
    Removes annotations from a set of specified extension functions.

    Arguments:
        module_source (str): The source code of the module to deannotate.
        extension_functions (Sequence[NativeAppExtensionFunction]): The list of extension functions
            to deannotate. Other functions encountered will be ignored.
        annotations_to_preserve (Sequence[str], optional): The list of annotations to preserve. The
            names should appear as they are found in the source code, e.g. "foo" for @foo or
            "annotations.bar" for @annotations.bar.

    Returns:
        A de-annotated version of the module source if any match was found; otherwise the
        source unchanged. In order to preserve line numbers, annotations are commented out
        (prefixed with "#: ") instead of being removed. When a match is found, line endings
        in the result are normalized to "\\n".
    """
    import re

    tree = ast.parse(module_source)

    definitions = _collect_ast_function_definitions(tree, extension_functions)
    if not definitions:
        return module_source

    # Split only on the line terminators the Python parser recognizes (\r\n, \r, \n) so that
    # list indices line up with AST line numbers. str.splitlines() also splits on characters
    # such as form feed or \u2028, which would shift every following line, and it drops the
    # trailing empty line, losing the final newline.
    module_lines = re.split(r"\r\n|\r|\n", module_source)
    for definition in definitions:
        # Comment out all decorators. As per the python grammar, decorators must be terminated by a
        # new line, so the line ranges can't overlap.
        for decorator in definition.decorator_list:
            decorator_id = _get_decorator_id(decorator)
            if decorator_id is None:
                continue
            if annotations_to_preserve and decorator_id in annotations_to_preserve:
                continue

            # AST line numbers are 1-based; list indices are 0-based.
            start_index = decorator.lineno - 1
            if decorator.end_lineno is not None:
                end_index = decorator.end_lineno - 1
            else:
                end_index = start_index

            for index in range(start_index, end_index + 1):
                module_lines[index] = "#: " + module_lines[index]

    # we're writing files in text mode, so we should use '\n' regardless of the platform
    return "\n".join(module_lines)
@@ -0,0 +1,47 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum
4
+ from typing import List, Optional
5
+
6
+ from pydantic import Field
7
+ from snowflake.cli.api.project.schemas.snowpark.callable import _CallableBase
8
+ from snowflake.cli.api.project.schemas.updatable_model import IdentifierField
9
+
10
+
11
class ExtensionFunctionTypeEnum(str, Enum):
    """
    Kinds of extension function. Values are lowercase, space-separated names
    (e.g. "table function"); this is the type of NativeAppExtensionFunction's
    ``type`` field.
    """

    PROCEDURE = "procedure"
    FUNCTION = "function"
    TABLE_FUNCTION = "table function"
    AGGREGATE_FUNCTION = "aggregate function"
16
+
17
+
18
class NativeAppExtensionFunction(_CallableBase):
    """
    Model of a single Snowpark extension function found in a Native App artifact.

    Fields beyond those of _CallableBase describe the function's kind, location and
    Native App specific settings. Presumably populated from the JSON emitted by the
    extension function collection script (keys "type" and "schema" map to the aliased
    fields below) - confirm against the processor.
    """

    function_type: ExtensionFunctionTypeEnum = Field(
        title="The type of extension function, one of 'procedure', 'function', 'table function' or 'aggregate function'",
        alias="type",
    )
    lineno: Optional[int] = Field(
        title="The starting line number of the extension function (1-based)",
        default=None,
    )
    name: Optional[str] = Field(
        title="The name of the extension function", default=None
    )
    packages: Optional[List[str]] = Field(
        title="List of packages (with optional version constraints) to be loaded for the function",
        default=[],
    )
    # Plain string title: the original f-string had no placeholders (ruff F541).
    schema_name: Optional[str] = IdentifierField(
        title="Name of the schema for the function",
        default=None,
        alias="schema",
    )
    application_roles: Optional[List[str]] = Field(
        title="Application roles granted usage to the function",
        default=[],
    )
    execute_as_caller: Optional[bool] = Field(
        title="Determine whether the procedure is executed with the privileges of "
        "the owner or with the privileges of the caller",
        default=False,
    )