snowflake-cli-labs 2.3.0rc1__py3-none-any.whl → 2.4.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cli/__about__.py +1 -1
- snowflake/cli/api/__init__.py +2 -0
- snowflake/cli/api/cli_global_context.py +8 -1
- snowflake/cli/api/commands/decorators.py +2 -2
- snowflake/cli/api/commands/flags.py +49 -4
- snowflake/cli/api/commands/snow_typer.py +2 -0
- snowflake/cli/api/console/abc.py +2 -0
- snowflake/cli/api/console/console.py +6 -5
- snowflake/cli/api/constants.py +5 -0
- snowflake/cli/api/exceptions.py +12 -0
- snowflake/cli/api/identifiers.py +123 -0
- snowflake/cli/api/plugins/command/__init__.py +2 -0
- snowflake/cli/api/plugins/plugin_config.py +2 -0
- snowflake/cli/api/project/definition.py +2 -0
- snowflake/cli/api/project/errors.py +3 -3
- snowflake/cli/api/project/schemas/identifier_model.py +35 -0
- snowflake/cli/api/project/schemas/native_app/native_app.py +4 -0
- snowflake/cli/api/project/schemas/native_app/path_mapping.py +21 -3
- snowflake/cli/api/project/schemas/project_definition.py +58 -6
- snowflake/cli/api/project/schemas/snowpark/argument.py +2 -0
- snowflake/cli/api/project/schemas/snowpark/callable.py +8 -17
- snowflake/cli/api/project/schemas/streamlit/streamlit.py +2 -2
- snowflake/cli/api/project/schemas/updatable_model.py +2 -0
- snowflake/cli/api/project/util.py +2 -0
- snowflake/cli/api/secure_path.py +2 -0
- snowflake/cli/api/sql_execution.py +14 -54
- snowflake/cli/api/utils/cursor.py +2 -0
- snowflake/cli/api/utils/models.py +23 -0
- snowflake/cli/api/utils/naming_utils.py +0 -27
- snowflake/cli/api/utils/rendering.py +178 -23
- snowflake/cli/app/api_impl/plugin/plugin_config_provider_impl.py +2 -0
- snowflake/cli/app/cli_app.py +4 -1
- snowflake/cli/app/commands_registration/builtin_plugins.py +8 -0
- snowflake/cli/app/commands_registration/command_plugins_loader.py +2 -0
- snowflake/cli/app/commands_registration/commands_registration_with_callbacks.py +2 -0
- snowflake/cli/app/commands_registration/typer_registration.py +2 -0
- snowflake/cli/app/dev/pycharm_remote_debug.py +2 -0
- snowflake/cli/app/loggers.py +2 -0
- snowflake/cli/app/main_typer.py +1 -1
- snowflake/cli/app/printing.py +3 -1
- snowflake/cli/app/snow_connector.py +2 -2
- snowflake/cli/plugins/connection/commands.py +5 -14
- snowflake/cli/plugins/connection/util.py +1 -1
- snowflake/cli/plugins/cortex/__init__.py +0 -0
- snowflake/cli/plugins/cortex/commands.py +312 -0
- snowflake/cli/plugins/cortex/constants.py +3 -0
- snowflake/cli/plugins/cortex/manager.py +175 -0
- snowflake/cli/plugins/cortex/plugin_spec.py +16 -0
- snowflake/cli/plugins/cortex/types.py +8 -0
- snowflake/cli/plugins/git/commands.py +15 -0
- snowflake/cli/plugins/nativeapp/artifacts.py +368 -123
- snowflake/cli/plugins/nativeapp/codegen/artifact_processor.py +45 -0
- snowflake/cli/plugins/nativeapp/codegen/compiler.py +104 -0
- snowflake/cli/plugins/nativeapp/codegen/sandbox.py +2 -0
- snowflake/cli/plugins/nativeapp/codegen/snowpark/callback_source.py.jinja +181 -0
- snowflake/cli/plugins/nativeapp/codegen/snowpark/extension_function_utils.py +196 -0
- snowflake/cli/plugins/nativeapp/codegen/snowpark/models.py +47 -0
- snowflake/cli/plugins/nativeapp/codegen/snowpark/python_processor.py +489 -0
- snowflake/cli/plugins/nativeapp/commands.py +11 -4
- snowflake/cli/plugins/nativeapp/common_flags.py +12 -5
- snowflake/cli/plugins/nativeapp/constants.py +1 -0
- snowflake/cli/plugins/nativeapp/manager.py +49 -16
- snowflake/cli/plugins/nativeapp/policy.py +2 -0
- snowflake/cli/plugins/nativeapp/run_processor.py +28 -10
- snowflake/cli/plugins/nativeapp/teardown_processor.py +78 -8
- snowflake/cli/plugins/nativeapp/utils.py +7 -6
- snowflake/cli/plugins/nativeapp/version/commands.py +6 -5
- snowflake/cli/plugins/nativeapp/version/version_processor.py +2 -0
- snowflake/cli/plugins/notebook/commands.py +21 -0
- snowflake/cli/plugins/notebook/exceptions.py +6 -0
- snowflake/cli/plugins/notebook/manager.py +46 -3
- snowflake/cli/plugins/notebook/types.py +2 -0
- snowflake/cli/plugins/object/command_aliases.py +80 -0
- snowflake/cli/plugins/object/commands.py +10 -6
- snowflake/cli/plugins/object/common.py +2 -0
- snowflake/cli/plugins/object_stage_deprecated/__init__.py +1 -0
- snowflake/cli/plugins/object_stage_deprecated/plugin_spec.py +20 -0
- snowflake/cli/plugins/snowpark/commands.py +62 -6
- snowflake/cli/plugins/snowpark/common.py +17 -6
- snowflake/cli/plugins/spcs/compute_pool/commands.py +22 -1
- snowflake/cli/plugins/spcs/compute_pool/manager.py +2 -0
- snowflake/cli/plugins/spcs/image_repository/commands.py +25 -1
- snowflake/cli/plugins/spcs/image_repository/manager.py +3 -1
- snowflake/cli/plugins/spcs/services/commands.py +39 -5
- snowflake/cli/plugins/spcs/services/manager.py +2 -0
- snowflake/cli/plugins/sql/commands.py +13 -5
- snowflake/cli/plugins/sql/manager.py +40 -19
- snowflake/cli/plugins/stage/commands.py +29 -3
- snowflake/cli/plugins/stage/diff.py +2 -0
- snowflake/cli/plugins/streamlit/commands.py +26 -10
- snowflake/cli/plugins/streamlit/manager.py +9 -10
- {snowflake_cli_labs-2.3.0rc1.dist-info → snowflake_cli_labs-2.4.0rc0.dist-info}/METADATA +4 -2
- {snowflake_cli_labs-2.3.0rc1.dist-info → snowflake_cli_labs-2.4.0rc0.dist-info}/RECORD +97 -77
- /snowflake/cli/plugins/{object/stage_deprecated → object_stage_deprecated}/commands.py +0 -0
- {snowflake_cli_labs-2.3.0rc1.dist-info → snowflake_cli_labs-2.4.0rc0.dist-info}/WHEEL +0 -0
- {snowflake_cli_labs-2.3.0rc1.dist-info → snowflake_cli_labs-2.4.0rc0.dist-info}/entry_points.txt +0 -0
- {snowflake_cli_labs-2.3.0rc1.dist-info → snowflake_cli_labs-2.4.0rc0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Dict, Optional
|
|
5
|
+
|
|
6
|
+
from snowflake.cli.api.console import cli_console as cc
|
|
7
|
+
from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp
|
|
8
|
+
from snowflake.cli.api.project.schemas.native_app.path_mapping import (
|
|
9
|
+
PathMapping,
|
|
10
|
+
ProcessorMapping,
|
|
11
|
+
)
|
|
12
|
+
from snowflake.cli.plugins.nativeapp.artifacts import resolve_without_follow
|
|
13
|
+
from snowflake.cli.plugins.nativeapp.codegen.artifact_processor import (
|
|
14
|
+
ArtifactProcessor,
|
|
15
|
+
UnsupportedArtifactProcessorError,
|
|
16
|
+
)
|
|
17
|
+
from snowflake.cli.plugins.nativeapp.codegen.snowpark.python_processor import (
|
|
18
|
+
SnowparkAnnotationProcessor,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Processor name for Snowpark annotation processing; compared against the lower-cased
# processor name from the project definition.
SNOWPARK_PROCESSOR = "snowpark"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class NativeAppCompiler:
    """
    Compiler class to perform custom processing on all relevant Native Apps artifacts (specified in the project definition file)
    before an application package can be created from those artifacts.
    An artifact can have more than one processor specified for itself, and this class will execute those processors in that order.
    The class also maintains a dictionary of processors it creates in order to reuse them across artifacts, since processor initialization
    is independent of the artifact to process.
    """

    def __init__(
        self,
        project_definition: NativeApp,
        project_root: Path,
        deploy_root: Path,
        generated_root: Path,
    ):
        """
        Parameters:
            project_definition: The native app project definition whose artifacts are processed.
            project_root, deploy_root, generated_root: Directories handed (resolved without
                following symlinks) to the processors this compiler creates.
        """
        self.project_definition = project_definition
        self.project_root = project_root
        self.deploy_root = deploy_root
        self.generated_root = generated_root

        # Only PathMapping artifacts can declare processors; other artifact entries are ignored.
        self.artifacts = [
            artifact
            for artifact in project_definition.artifacts
            if isinstance(artifact, PathMapping)
        ]
        # dictionary of all processors created and shared between different artifact objects.
        self.cached_processors: Dict[str, ArtifactProcessor] = {}

    def compile_artifacts(self):
        """
        Go through every artifact object in the project definition of a native app, and execute processors in order of specification for each of the artifact object.
        May have side-effects on the filesystem by either directly editing source files or the deploy root.

        Raises:
            UnsupportedArtifactProcessorError: if an artifact names a processor that this
                compiler cannot create.
        """
        if not any(artifact.processors for artifact in self.artifacts):
            # Nothing to process; return before opening a console phase.
            return

        with cc.phase("Invoking artifact processors"):
            for artifact in self.artifacts:
                # Treat a missing (None) processor list like an empty one, consistent with the
                # truthiness check above.
                for processor in artifact.processors or ():
                    artifact_processor = self._try_create_processor(
                        processor_mapping=processor,
                    )
                    if artifact_processor is None:
                        raise UnsupportedArtifactProcessorError(
                            processor_name=processor.name
                        )
                    artifact_processor.process(
                        artifact_to_process=artifact, processor_mapping=processor
                    )

    def _try_create_processor(
        self,
        processor_mapping: ProcessorMapping,
        **kwargs,
    ) -> Optional[ArtifactProcessor]:
        """
        Fetch processor object if one already exists in the cached_processors dictionary.
        Else, initialize a new object to return, and add it to the cached_processors dictionary.

        Returns None when the processor name (compared case-insensitively) is not supported;
        only the Snowpark processor is supported. Extra keyword arguments are accepted but unused.
        """
        if processor_mapping.name.lower() != SNOWPARK_PROCESSOR:
            return None

        curr_processor = self.cached_processors.get(SNOWPARK_PROCESSOR)
        if curr_processor is None:
            curr_processor = SnowparkAnnotationProcessor(
                project_definition=self.project_definition,
                project_root=resolve_without_follow(self.project_root),
                deploy_root=resolve_without_follow(self.deploy_root),
                generated_root=resolve_without_follow(self.generated_root),
            )
            self.cached_processors[SNOWPARK_PROCESSOR] = curr_processor
        return curr_processor
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import functools
|
|
3
|
+
import inspect
|
|
4
|
+
import sys
|
|
5
|
+
from typing import Callable
|
|
6
|
+
|
|
7
|
+
try:
    import snowflake.snowpark
except ModuleNotFoundError as exc:
    print(
        "An exception occurred while importing snowflake-snowpark-python package: ",
        exc,
        file=sys.stderr,
    )
    sys.exit(1)

# Both private Snowpark hooks that this script patches must exist; older Snowpark
# releases do not have them, in which case we abort with an upgrade hint.
__snowflake_internal_found_correct_version = all(
    hasattr(snowflake.snowpark.context, hook_name)
    for hook_name in (
        "_is_execution_environment_sandboxed_for_client",
        "_should_continue_registration",
    )
)

if not __snowflake_internal_found_correct_version:
    print(
        "Did not find the minimum required version for snowflake-snowpark-python package. Please upgrade to v1.15.0 or higher.",
        file=sys.stderr,
    )
    sys.exit(1)

# One JSON-serializable dict per collected extension function; printed to stdout at the end.
__snowflake_global_collected_extension_fn_json = []
|
|
29
|
+
|
|
30
|
+
def __snowflake_internal_create_extension_fn_registration_callback():
    """
    Creates the callback that is installed as Snowpark's ``_should_continue_registration`` hook.

    Every time the user's module registers an extension function, the callback converts the
    registration properties into a JSON-serializable dict, appends it to
    ``__snowflake_global_collected_extension_fn_json`` and returns False. Registrations that
    cannot be described (non-callable handlers, anonymous functions) are skipped.
    """

    def __snowflake_internal_try_extract_lineno(extension_fn):
        # Best effort: the source of the handler may not be retrievable.
        try:
            return inspect.getsourcelines(extension_fn)[1]
        except Exception:
            return None

    def __snowflake_internal_extract_extension_fn_name(extension_fn):
        try:
            import snowflake.snowpark._internal.utils as snowpark_utils

            if hasattr(snowpark_utils, 'TEMP_OBJECT_NAME_PREFIX'):
                if extension_fn.object_name.startswith(snowpark_utils.TEMP_OBJECT_NAME_PREFIX):
                    # The object name is a generated one, don't use it
                    return None

        except Exception:
            # ignore any exception and fall back to using the object name reported from Snowpark
            pass

        return extension_fn.object_name

    def __snowflake_internal_create_package_list(extension_fn):
        # all_packages is a comma-separated string; a blank string means "no packages".
        if not extension_fn.all_packages.strip():
            return []
        return [pkg_spec.strip() for pkg_spec in extension_fn.all_packages.split(",")]

    def __snowflake_internal_make_extension_fn_signature(extension_fn):
        # Try to fetch the original argument names (and Python defaults) from the handler
        try:
            args_spec = inspect.getfullargspec(extension_fn.func)
            original_arg_names = args_spec.args
            input_sql_types = extension_fn.input_sql_types
            # The SQL inputs map onto the trailing positional parameters of the handler; leading
            # parameters without a SQL type (presumably e.g. a procedure's session) are skipped.
            start_index = len(original_arg_names) - len(input_sql_types)
            if start_index >= 0:
                defaults = args_spec.defaults or ()
                # Python defaults belong to the last len(defaults) positional parameters, so both
                # this index and arg_index below are absolute positions in original_arg_names.
                defaults_start_index = len(original_arg_names) - len(defaults)
                signature = []
                for offset, input_type in enumerate(input_sql_types):
                    arg_index = start_index + offset
                    arg = {
                        'name': original_arg_names[arg_index],
                        'type': input_type,
                    }
                    if arg_index >= defaults_start_index:
                        arg['default'] = defaults[arg_index - defaults_start_index]
                    signature.append(arg)

                return signature
        except Exception:
            pass  # ignore, we'll use the fallback strategy

        # Failed to extract the original arguments through reflection, fall back to alternative approach
        return [
            {"name": input_arg.name, "type": input_type}
            for (input_arg, input_type) in zip(extension_fn.input_args, extension_fn.input_sql_types)
        ]

    def __snowflake_internal_to_extension_fn_type(object_type):
        if object_type.name == "AGGREGATE_FUNCTION":
            return "aggregate function"
        if object_type.name == "TABLE_FUNCTION":
            return "table function"
        return object_type.name.lower()

    def __snowflake_internal_imports_union_to_str_type(raw_imports):
        # Imports are either plain paths or tuples whose first element is the path.
        final_imports = []
        if raw_imports:
            for raw_import in raw_imports:
                if isinstance(raw_import, str):
                    final_imports.append(raw_import)
                else:
                    final_imports.append(raw_import[0])
        return final_imports

    def __snowflake_internal_extension_fn_to_json(extension_fn):
        if not isinstance(extension_fn.func, Callable):
            # Unsupported case: extension function is a tuple
            return None

        if extension_fn.anonymous:
            # unsupported, native application extension functions need to be explicitly named
            return None

        # Collect basic properties of the extension function
        extension_fn_json = {
            "type": __snowflake_internal_to_extension_fn_type(extension_fn.object_type),
            "lineno": __snowflake_internal_try_extract_lineno(extension_fn.func),
            "name": __snowflake_internal_extract_extension_fn_name(extension_fn),
            "handler": extension_fn.func.__name__,
            "imports": __snowflake_internal_imports_union_to_str_type(extension_fn.raw_imports),
            "packages": __snowflake_internal_create_package_list(extension_fn),
            "runtime": extension_fn.runtime_version,
            "returns": extension_fn.return_sql.upper().replace("RETURNS ", "").strip(),
            "signature": __snowflake_internal_make_extension_fn_signature(extension_fn),
            "external_access_integrations": extension_fn.external_access_integrations or [],
            "secrets": extension_fn.secrets or {},
        }

        if extension_fn.object_type.name == "PROCEDURE" and extension_fn.execute_as is not None:
            extension_fn_json['execute_as_caller'] = (extension_fn.execute_as == 'caller')

        if extension_fn.native_app_params is not None:
            schema = extension_fn.native_app_params.get("schema")
            if schema is not None:
                extension_fn_json["schema"] = schema
            app_roles = extension_fn.native_app_params.get("application_roles")
            if app_roles is not None:
                extension_fn_json["application_roles"] = app_roles

        return extension_fn_json

    def __snowflake_internal_collect_extension_fn(
        collected_extension_fn_json_list, extension_function_properties
    ):
        extension_fn_json = __snowflake_internal_extension_fn_to_json(extension_function_properties)
        # Unsupported extension functions produce no JSON; recording them would put `null`
        # entries into the printed list.
        if extension_fn_json is not None:
            collected_extension_fn_json_list.append(extension_fn_json)
        # The hook is named "_should_continue_registration"; returning False presumably tells
        # Snowpark not to proceed with the actual registration.
        return False

    return functools.partial(
        __snowflake_internal_collect_extension_fn,
        __snowflake_global_collected_extension_fn_json,
    )
|
|
151
|
+
|
|
152
|
+
# Route Snowpark's registration machinery through our callback. The callback records each
# extension function and returns False, which (going by the hook's name) presumably stops
# Snowpark from carrying on with the actual registration.
snowflake.snowpark.context._is_execution_environment_sandboxed_for_client = True  # noqa: SLF001
snowflake.snowpark.context._should_continue_registration = (  # noqa: SLF001
    __snowflake_internal_create_extension_fn_registration_callback()
)
snowflake.snowpark.session._is_execution_environment_sandboxed_for_client = True  # noqa: SLF001

# Drop every "__snowflake_internal*" helper from the module namespace. The callback installed
# above keeps working: its nested helpers are captured as closures, the result list is bound
# into the partial, and the "__snowflake_global_*" list itself is not removed.
for global_key in list(globals()):
    if not global_key.startswith("__snowflake_internal"):
        continue
    del globals()[global_key]

del globals()["global_key"]  # the loop variable is a module global too; remove it as well
|
|
167
|
+
|
|
168
|
+
try:
    # `import importlib` alone does not guarantee that the `importlib.util` submodule is
    # loaded, and spec_from_file_location/module_from_spec live there; import it explicitly.
    import importlib.util

    # Discard Python-level output of the user's module while it executes.
    # NOTE(review): "{{py_file}}" is substituted verbatim into a string literal; backslashes or
    # quotes in the path (e.g. Windows paths) would break the generated script — confirm how
    # the renderer supplies py_file.
    with contextlib.redirect_stdout(None), contextlib.redirect_stderr(None):
        __snowflake_internal_spec = importlib.util.spec_from_file_location("<string>", "{{py_file}}")
        __snowflake_internal_module = importlib.util.module_from_spec(__snowflake_internal_spec)
        __snowflake_internal_spec.loader.exec_module(__snowflake_internal_module)
except Exception as exc:  # Catch any error
    print("An exception occurred while executing file: ", exc, file=sys.stderr)
    sys.exit(1)


import json

# Emit everything the registration callback collected as a single JSON list on stdout.
print(json.dumps(__snowflake_global_collected_extension_fn_json))
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
from typing import (
|
|
5
|
+
Any,
|
|
6
|
+
List,
|
|
7
|
+
Optional,
|
|
8
|
+
Sequence,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
from click.exceptions import ClickException
|
|
12
|
+
from snowflake.cli.api.project.schemas.snowpark.argument import Argument
|
|
13
|
+
from snowflake.cli.api.project.util import (
|
|
14
|
+
is_valid_identifier,
|
|
15
|
+
is_valid_string_literal,
|
|
16
|
+
to_identifier,
|
|
17
|
+
to_string_literal,
|
|
18
|
+
)
|
|
19
|
+
from snowflake.cli.plugins.nativeapp.codegen.snowpark.models import (
|
|
20
|
+
ExtensionFunctionTypeEnum,
|
|
21
|
+
NativeAppExtensionFunction,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MalformedExtensionFunctionError(ClickException):
    """Signals that an extension function lacks an attribute required to process it."""

    def __init__(self, message: str):
        super(MalformedExtensionFunctionError, self).__init__(message=message)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_sql_object_type(extension_fn: NativeAppExtensionFunction) -> Optional[str]:
    """
    Returns the SQL object-type keyword for the given extension function: "PROCEDURE",
    "FUNCTION" (also used for table functions) or "AGGREGATE FUNCTION". Returns None for any
    other function type.
    """
    function_type = extension_fn.function_type
    if function_type == ExtensionFunctionTypeEnum.PROCEDURE:
        return "PROCEDURE"
    if function_type in (
        ExtensionFunctionTypeEnum.FUNCTION,
        ExtensionFunctionTypeEnum.TABLE_FUNCTION,
    ):
        return "FUNCTION"
    # Compare against the enum class member, like the branches above, rather than looking the
    # member up through the instance; this also works if function_type is a plain string.
    if function_type == ExtensionFunctionTypeEnum.AGGREGATE_FUNCTION:
        return "AGGREGATE FUNCTION"
    return None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_sql_argument_signature(arg: Argument) -> str:
    """
    Renders one argument as it appears in a SQL signature: "<name> <type>", followed by
    " DEFAULT <value>" when the argument has a default.
    """
    default_clause = "" if arg.default is None else f" DEFAULT {arg.default}"
    return f"{arg.name} {arg.arg_type}{default_clause}"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_function_type_signature_for_grant(
    extension_fn: NativeAppExtensionFunction,
) -> str:
    """
    Builds the comma-separated list of argument types of the function (e.g. "int, varchar"),
    in the form needed inside a GRANT statement.
    """
    return ", ".join(argument.arg_type for argument in extension_fn.signature)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def get_qualified_object_name(extension_fn: NativeAppExtensionFunction) -> str:
    """
    Returns the function's identifier, prefixed with its schema when one is set.

    A schema that is already a valid identifier is used verbatim; otherwise it is split on "."
    and each part is passed through to_identifier().
    """
    object_name = to_identifier(extension_fn.name)
    schema = extension_fn.schema_name
    if not schema:
        return object_name

    if not is_valid_identifier(schema):
        # NOTE(review): a plain split on "." also splits inside quoted parts such as
        # db."my.schema" — confirm such schema names cannot reach this point.
        schema = ".".join(to_identifier(part) for part in schema.split("."))

    return f"{schema}.{object_name}"
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def ensure_string_literal(value: str) -> str:
    """
    Returns value unchanged if it already is a valid string literal, otherwise its string
    literal representation.
    """
    return value if is_valid_string_literal(value) else to_string_literal(value)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def ensure_all_string_literals(values: Sequence[str]) -> List[str]:
    """
    Applies ensure_string_literal to every value.

    Returns:
        A new list, in input order, where each value is a valid string literal.
    """
    return list(map(ensure_string_literal, values))
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class _FunctionDefAccumulator(ast.NodeVisitor):
|
|
100
|
+
"""
|
|
101
|
+
A NodeVisitor that collects AST nodes corresponding to function declarations, filtered by a list
|
|
102
|
+
of wanted functions. This is used to identify all Snowpark extension functions in a module's
|
|
103
|
+
source code.
|
|
104
|
+
"""
|
|
105
|
+
|
|
106
|
+
def __init__(self, functions: Sequence[NativeAppExtensionFunction]):
|
|
107
|
+
self._wanted_functions_by_name = {
|
|
108
|
+
fn.handler.split(".")[-1]: fn for fn in functions
|
|
109
|
+
}
|
|
110
|
+
self.definitions: List[Any] = []
|
|
111
|
+
|
|
112
|
+
def visit_FunctionDef(self, node: ast.FunctionDef): # noqa: N802
|
|
113
|
+
if self._want(node):
|
|
114
|
+
self.definitions.append(node)
|
|
115
|
+
self.generic_visit(node)
|
|
116
|
+
|
|
117
|
+
def _want(self, node: Any) -> bool:
|
|
118
|
+
if not node.decorator_list:
|
|
119
|
+
# No decorators for this definition, ignore it
|
|
120
|
+
return False
|
|
121
|
+
|
|
122
|
+
return node.name in self._wanted_functions_by_name
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _get_decorator_id(node: ast.AST) -> Optional[str]:
|
|
126
|
+
"""
|
|
127
|
+
Returns the fully qualified identifier for a decorator, e.g. "foo" or "foo.bar".
|
|
128
|
+
"""
|
|
129
|
+
if isinstance(node, ast.Name):
|
|
130
|
+
return node.id
|
|
131
|
+
elif isinstance(node, ast.Attribute):
|
|
132
|
+
return f"{_get_decorator_id(node.value)}.{node.attr}"
|
|
133
|
+
elif isinstance(node, ast.Call):
|
|
134
|
+
return _get_decorator_id(node.func)
|
|
135
|
+
else:
|
|
136
|
+
return None
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _collect_ast_function_definitions(
    tree: ast.AST, extension_functions: Sequence[NativeAppExtensionFunction]
) -> Sequence[ast.FunctionDef]:
    """Walks tree and returns the decorated definitions matching the given extension functions."""
    visitor = _FunctionDefAccumulator(extension_functions)
    visitor.visit(tree)
    found = visitor.definitions
    return found
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def deannotate_module_source(
    module_source: str,
    extension_functions: Sequence[NativeAppExtensionFunction],
    annotations_to_preserve: Sequence[str] = (),
) -> str:
    """
    Removes annotations from a set of specified extension functions.

    Arguments:
        module_source (str): The source code of the module to deannotate.
        extension_functions (Sequence[NativeAppExtensionFunction]): The list of extension functions
            to deannotate. Other functions encountered will be ignored.
        annotations_to_preserve (Sequence[str], optional): The list of annotations to preserve. The
            names should appear as they are found in the source code, e.g. "foo" for @foo or
            "annotations.bar" for @annotations.bar.

    Returns:
        A de-annotated version of the module source if any match was found. In order to preserve
        line numbers, annotations are simply commented out instead of completely removed. Line
        endings of a modified source are normalized to '\\n'.
    """

    tree = ast.parse(module_source)

    definitions = _collect_ast_function_definitions(tree, extension_functions)
    if not definitions:
        return module_source

    # AST line numbers only treat "\n", "\r\n" and "\r" as line breaks, but str.splitlines()
    # also splits on characters such as form feeds ("\x0c") or "\u2028". That would shift the
    # list indices and comment out the wrong lines. Normalize line endings and split on "\n"
    # only; this also keeps a trailing newline when the lines are joined back together.
    normalized_source = module_source.replace("\r\n", "\n").replace("\r", "\n")
    module_lines = normalized_source.split("\n")
    for definition in definitions:
        # Comment out all decorators. As per the python grammar, decorators must be terminated by a
        # new line, so the line ranges can't overlap.
        for decorator in definition.decorator_list:
            decorator_id = _get_decorator_id(decorator)
            if decorator_id is None:
                continue
            if annotations_to_preserve and decorator_id in annotations_to_preserve:
                continue

            # AST indices are 1-based
            start_lineno = decorator.lineno - 1
            if decorator.end_lineno is not None:
                end_lineno = decorator.end_lineno - 1
            else:
                end_lineno = start_lineno

            for lineno in range(start_lineno, end_lineno + 1):
                module_lines[lineno] = "#: " + module_lines[lineno]

    # we're writing files in text mode, so we should use '\n' regardless of the platform
    return "\n".join(module_lines)
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
|
|
6
|
+
from pydantic import Field
|
|
7
|
+
from snowflake.cli.api.project.schemas.snowpark.callable import _CallableBase
|
|
8
|
+
from snowflake.cli.api.project.schemas.updatable_model import IdentifierField
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ExtensionFunctionTypeEnum(str, Enum):
    # Kinds of extension function that can be declared. Values are the lowercase spellings
    # accepted for the "type" field of NativeAppExtensionFunction. Documented with comments
    # rather than a docstring so that any schema generated from this enum stays unchanged.
    PROCEDURE = "procedure"
    FUNCTION = "function"
    TABLE_FUNCTION = "table function"
    AGGREGATE_FUNCTION = "aggregate function"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class NativeAppExtensionFunction(_CallableBase):
    # Model of a single Snowpark extension function of a native app. Documented with comments
    # instead of a class docstring, because pydantic may use a class docstring as the schema
    # description and the generated schema should stay unchanged.

    # Aliased to "type" in input data.
    function_type: ExtensionFunctionTypeEnum = Field(
        title="The type of extension function, one of 'procedure', 'function', 'table function' or 'aggregate function'",
        alias="type",
    )
    lineno: Optional[int] = Field(
        title="The starting line number of the extension function (1-based)",
        default=None,
    )
    # Optional here; overrides whatever _CallableBase declares for `name`.
    name: Optional[str] = Field(
        title="The name of the extension function", default=None
    )
    packages: Optional[List[str]] = Field(
        title="List of packages (with optional version constraints) to be loaded for the function",
        default=[],
    )
    # Aliased to "schema" in input data.
    schema_name: Optional[str] = IdentifierField(
        title="Name of the schema for the function",
        default=None,
        alias="schema",
    )
    application_roles: Optional[List[str]] = Field(
        title="Application roles granted usage to the function",
        default=[],
    )
    execute_as_caller: Optional[bool] = Field(
        title="Determine whether the procedure is executed with the privileges of "
        "the owner or with the privileges of the caller",
        default=False,
    )
|