snowflake-cli-labs 2.5.0rc3__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. snowflake/cli/__about__.py +1 -1
  2. snowflake/cli/api/cli_global_context.py +31 -3
  3. snowflake/cli/api/commands/decorators.py +21 -6
  4. snowflake/cli/api/commands/flags.py +60 -51
  5. snowflake/cli/api/commands/snow_typer.py +24 -0
  6. snowflake/cli/api/commands/typer_pre_execute.py +26 -0
  7. snowflake/cli/api/console/abc.py +8 -0
  8. snowflake/cli/api/console/console.py +29 -4
  9. snowflake/cli/api/constants.py +3 -0
  10. snowflake/cli/api/project/definition.py +17 -35
  11. snowflake/cli/api/project/definition_manager.py +22 -19
  12. snowflake/cli/api/project/errors.py +9 -6
  13. snowflake/cli/api/project/schemas/identifier_model.py +1 -1
  14. snowflake/cli/api/project/schemas/native_app/application.py +15 -3
  15. snowflake/cli/api/project/schemas/native_app/native_app.py +5 -1
  16. snowflake/cli/api/project/schemas/native_app/path_mapping.py +14 -3
  17. snowflake/cli/api/project/schemas/project_definition.py +37 -6
  18. snowflake/cli/api/project/schemas/streamlit/streamlit.py +3 -0
  19. snowflake/cli/api/project/schemas/updatable_model.py +2 -6
  20. snowflake/cli/api/rest_api.py +113 -0
  21. snowflake/cli/api/sanitizers.py +43 -0
  22. snowflake/cli/api/sql_execution.py +7 -0
  23. snowflake/cli/api/utils/definition_rendering.py +95 -25
  24. snowflake/cli/api/utils/models.py +31 -26
  25. snowflake/cli/api/utils/rendering.py +24 -3
  26. snowflake/cli/app/cli_app.py +2 -0
  27. snowflake/cli/app/commands_registration/command_plugins_loader.py +8 -0
  28. snowflake/cli/app/dev/docs/commands_docs_generator.py +100 -0
  29. snowflake/cli/app/dev/docs/generator.py +8 -67
  30. snowflake/cli/app/dev/docs/project_definition_docs_generator.py +58 -0
  31. snowflake/cli/app/dev/docs/project_definition_generate_json_schema.py +227 -0
  32. snowflake/cli/app/dev/docs/template_utils.py +23 -0
  33. snowflake/cli/app/dev/docs/templates/definition_description.rst.jinja2 +38 -0
  34. snowflake/cli/app/dev/docs/templates/usage.rst.jinja2 +6 -1
  35. snowflake/cli/app/loggers.py +25 -0
  36. snowflake/cli/app/printing.py +7 -5
  37. snowflake/cli/app/telemetry.py +11 -0
  38. snowflake/cli/plugins/nativeapp/artifacts.py +78 -9
  39. snowflake/cli/plugins/nativeapp/codegen/artifact_processor.py +3 -11
  40. snowflake/cli/plugins/nativeapp/codegen/compiler.py +6 -24
  41. snowflake/cli/plugins/nativeapp/codegen/snowpark/python_processor.py +27 -27
  42. snowflake/cli/plugins/nativeapp/commands.py +23 -12
  43. snowflake/cli/plugins/nativeapp/constants.py +2 -0
  44. snowflake/cli/plugins/nativeapp/errno.py +15 -0
  45. snowflake/cli/plugins/nativeapp/feature_flags.py +24 -0
  46. snowflake/cli/plugins/nativeapp/init.py +5 -0
  47. snowflake/cli/plugins/nativeapp/manager.py +101 -103
  48. snowflake/cli/plugins/nativeapp/project_model.py +181 -0
  49. snowflake/cli/plugins/nativeapp/run_processor.py +178 -110
  50. snowflake/cli/plugins/nativeapp/teardown_processor.py +89 -64
  51. snowflake/cli/plugins/nativeapp/utils.py +2 -2
  52. snowflake/cli/plugins/nativeapp/version/commands.py +3 -3
  53. snowflake/cli/plugins/object/commands.py +70 -4
  54. snowflake/cli/plugins/object/manager.py +44 -3
  55. snowflake/cli/plugins/snowpark/commands.py +2 -2
  56. snowflake/cli/plugins/sql/commands.py +2 -10
  57. snowflake/cli/plugins/sql/manager.py +4 -2
  58. snowflake/cli/plugins/stage/commands.py +23 -4
  59. snowflake/cli/plugins/stage/diff.py +81 -51
  60. snowflake/cli/plugins/stage/manager.py +2 -1
  61. snowflake/cli/plugins/streamlit/commands.py +2 -1
  62. snowflake/cli/plugins/streamlit/manager.py +6 -0
  63. {snowflake_cli_labs-2.5.0rc3.dist-info → snowflake_cli_labs-2.6.0.dist-info}/METADATA +15 -9
  64. {snowflake_cli_labs-2.5.0rc3.dist-info → snowflake_cli_labs-2.6.0.dist-info}/RECORD +67 -56
  65. {snowflake_cli_labs-2.5.0rc3.dist-info → snowflake_cli_labs-2.6.0.dist-info}/WHEEL +1 -1
  66. {snowflake_cli_labs-2.5.0rc3.dist-info → snowflake_cli_labs-2.6.0.dist-info}/entry_points.txt +0 -0
  67. {snowflake_cli_labs-2.5.0rc3.dist-info → snowflake_cli_labs-2.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -14,15 +14,16 @@
14
14
 
15
15
  from textwrap import dedent
16
16
 
17
+ from click import ClickException
17
18
  from pydantic import ValidationError
18
19
 
19
20
 
20
- class SchemaValidationError(Exception):
21
- generic_message = "For field {loc} you provided '{loc}'. This caused: {msg}"
21
+ class SchemaValidationError(ClickException):
22
+ generic_message = "For field {location} you provided '{input}'. This caused: {msg}"
22
23
  message_templates = {
23
- "string_type": "{msg} for field '{loc}', you provided '{input}'.",
24
- "extra_forbidden": "{msg}. You provided field '{loc}' with value '{input}' that is not supported in given version.",
25
- "missing": "Your project definition is missing following fields: {loc}",
24
+ "string_type": "{msg} for field '{location}', you provided '{input}'.",
25
+ "extra_forbidden": "{msg}. You provided field '{location}' with value '{input}' that is not supported in given version.",
26
+ "missing": "Your project definition is missing the following field: '{location}'",
26
27
  }
27
28
 
28
29
  def __init__(self, error: ValidationError):
@@ -30,7 +31,9 @@ class SchemaValidationError(Exception):
30
31
  message = f"During evaluation of {error.title} in project definition following errors were encountered:\n"
31
32
  message += "\n".join(
32
33
  [
33
- self.message_templates.get(e["type"], self.generic_message).format(**e)
34
+ self.message_templates.get(e["type"], self.generic_message).format(
35
+ **e, location=".".join(e["loc"]) if e["loc"] is not None else None
36
+ )
34
37
  for e in errors
35
38
  ]
36
39
  )
@@ -36,7 +36,7 @@ def ObjectIdentifierModel(object_name: str) -> ObjectIdentifierBaseModel: # noq
36
36
  """Generates ObjectIdentifierBaseModel but with object specific descriptions."""
37
37
 
38
38
  class _ObjectIdentifierModel(ObjectIdentifierBaseModel):
39
- name: str = Field(title=f"{object_name} name")
39
+ name: str = Field(title=f"{object_name.capitalize()} name")
40
40
  database: Optional[str] = IdentifierField(
41
41
  title=f"Name of the database for the {object_name}", default=None
42
42
  )
@@ -14,7 +14,7 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- from typing import Optional
17
+ from typing import List, Optional
18
18
 
19
19
  from pydantic import Field
20
20
  from snowflake.cli.api.project.schemas.updatable_model import (
@@ -23,6 +23,14 @@ from snowflake.cli.api.project.schemas.updatable_model import (
23
23
  )
24
24
 
25
25
 
26
+ class SqlScriptHookType(UpdatableModel):
27
+ sql_script: str = Field(title="SQL file path relative to the project root")
28
+
29
+
30
+ # Currently sql_script is the only supported hook type. Change to a Union once other hook types are added
31
+ ApplicationPostDeployHook = SqlScriptHookType
32
+
33
+
26
34
  class Application(UpdatableModel):
27
35
  role: Optional[str] = Field(
28
36
  title="Role to use when creating the application object and consumer-side objects",
@@ -37,6 +45,10 @@ class Application(UpdatableModel):
37
45
  default=None,
38
46
  )
39
47
  debug: Optional[bool] = Field(
40
- title="Whether to enable debug mode when using a named stage to create an application object",
41
- default=True,
48
+ title="When set, forces debug_mode on/off for the deployed application object",
49
+ default=None,
50
+ )
51
+ post_deploy: Optional[List[ApplicationPostDeployHook]] = Field(
52
+ title="Actions that will be executed after the application object is created/upgraded",
53
+ default=None,
42
54
  )
@@ -34,8 +34,12 @@ class NativeApp(UpdatableModel):
34
34
  artifacts: List[Union[PathMapping, str]] = Field(
35
35
  title="List of file source and destination pairs to add to the deploy root",
36
36
  )
37
+ bundle_root: Optional[str] = Field(
38
+ title="Folder at the root of your project where artifacts necessary to perform the bundle step are stored.",
39
+ default="output/bundle/",
40
+ )
37
41
  deploy_root: Optional[str] = Field(
38
- title="Folder at the root of your project where the build step copies the artifacts.",
42
+ title="Folder at the root of your project where the bundle step copies the artifacts.",
39
43
  default="output/deploy/",
40
44
  )
41
45
  generated_root: Optional[str] = Field(
@@ -31,9 +31,20 @@ class ProcessorMapping(UpdatableModel):
31
31
 
32
32
 
33
33
  class PathMapping(UpdatableModel):
34
- src: str
35
- dest: Optional[str] = None
36
- processors: Optional[List[Union[str, ProcessorMapping]]] = []
34
+ src: str = Field(
35
+ title="Source path or glob pattern (relative to project root)", default=None
36
+ )
37
+
38
+ dest: Optional[str] = Field(
39
+ title="Destination path on stage",
40
+ description="Paths are relative to stage root; paths ending with a slash indicate that the destination is a directory which source files should be copied into.",
41
+ default=None,
42
+ )
43
+
44
+ processors: Optional[List[Union[str, ProcessorMapping]]] = Field(
45
+ title="List of processors to apply to matching source files during bundling.",
46
+ default=[],
47
+ )
37
48
 
38
49
  @field_validator("processors")
39
50
  @classmethod
@@ -14,18 +14,45 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ from dataclasses import dataclass
17
18
  from typing import Any, Dict, Optional, Union
18
19
 
19
20
  from packaging.version import Version
20
- from pydantic import Field, field_validator
21
+ from pydantic import Field, ValidationError, field_validator
22
+ from snowflake.cli.api.project.errors import SchemaValidationError
21
23
  from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp
22
24
  from snowflake.cli.api.project.schemas.snowpark.snowpark import Snowpark
23
25
  from snowflake.cli.api.project.schemas.streamlit.streamlit import Streamlit
24
26
  from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel
25
- from snowflake.cli.api.utils.models import EnvironWithDefinedDictFallback
27
+ from snowflake.cli.api.utils.models import ProjectEnvironment
28
+ from snowflake.cli.api.utils.types import Context
29
+
30
+
31
+ @dataclass
32
+ class ProjectProperties:
33
+ """
34
+ This class stores 2 objects representing the snowflake project:
35
+
36
+ The project_context object:
37
+ - Used as the context for templating when users reference variables in the project definition file.
38
+
39
+ The project_definition object:
40
+ - This is a transformed object type through Pydantic, which has been normalized.
41
+ - This object could have slightly different structure than what the users see in their yaml project definition files.
42
+ - This should be used for the business logic of snow CLI modules.
43
+ """
44
+
45
+ project_definition: ProjectDefinition
46
+ project_context: Context
26
47
 
27
48
 
28
49
  class _BaseDefinition(UpdatableModel):
50
+ def __init__(self, *args, **kwargs):
51
+ try:
52
+ super().__init__(**kwargs)
53
+ except ValidationError as e:
54
+ raise SchemaValidationError(e) from e
55
+
29
56
  definition_version: Union[str, int] = Field(
30
57
  title="Version of the project definition schema, which is currently 1",
31
58
  )
@@ -58,17 +85,21 @@ class _DefinitionV10(_BaseDefinition):
58
85
 
59
86
 
60
87
  class _DefinitionV11(_DefinitionV10):
61
- env: Optional[Dict] = Field(
88
+ env: Union[Dict[str, str], ProjectEnvironment, None] = Field(
62
89
  title="Environment specification for this project.",
63
90
  default=None,
64
91
  validation_alias="env",
92
+ union_mode="smart",
65
93
  )
66
94
 
67
95
  @field_validator("env")
68
96
  @classmethod
69
- def _convert_env(cls, env: Optional[Dict]) -> EnvironWithDefinedDictFallback:
70
- variables = EnvironWithDefinedDictFallback(env if env else {})
71
- return variables
97
+ def _convert_env(
98
+ cls, env: Union[Dict, ProjectEnvironment, None]
99
+ ) -> ProjectEnvironment:
100
+ if isinstance(env, ProjectEnvironment):
101
+ return env
102
+ return ProjectEnvironment(default_env=(env or {}), override_env={})
72
103
 
73
104
 
74
105
  class ProjectDefinition(_DefinitionV11):
@@ -40,3 +40,6 @@ class Streamlit(UpdatableModel, ObjectIdentifierModel(object_name="Streamlit")):
40
40
  title="List of additional files which should be included into deployment artifacts",
41
41
  default=None,
42
42
  )
43
+ title: Optional[str] = Field(
44
+ title="Human-readable title for the Streamlit dashboard", default=None
45
+ )
@@ -16,8 +16,7 @@ from __future__ import annotations
16
16
 
17
17
  from typing import Any, Dict
18
18
 
19
- from pydantic import BaseModel, ConfigDict, Field, ValidationError
20
- from snowflake.cli.api.project.errors import SchemaValidationError
19
+ from pydantic import BaseModel, ConfigDict, Field
21
20
  from snowflake.cli.api.project.util import IDENTIFIER_NO_LENGTH
22
21
 
23
22
 
@@ -25,10 +24,7 @@ class UpdatableModel(BaseModel):
25
24
  model_config = ConfigDict(validate_assignment=True, extra="forbid")
26
25
 
27
26
  def __init__(self, *args, **kwargs):
28
- try:
29
- super().__init__(**kwargs)
30
- except ValidationError as e:
31
- raise SchemaValidationError(e)
27
+ super().__init__(**kwargs)
32
28
 
33
29
  def update_from_dict(self, update_values: Dict[str, Any]):
34
30
  """
@@ -0,0 +1,113 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import logging
19
+ from typing import Any, Dict, List, Optional
20
+
21
+ from snowflake.cli.api.constants import SF_REST_API_URL_PREFIX
22
+ from snowflake.connector.connection import SnowflakeConnection
23
+ from snowflake.connector.errors import InterfaceError
24
+ from snowflake.connector.network import SnowflakeRestful
25
+
26
+ log = logging.getLogger(__name__)
27
+
28
+
29
+ class RestApi:
30
+ def __init__(self, connection: SnowflakeConnection):
31
+ self.conn = connection
32
+ self.rest: SnowflakeRestful = connection.rest
33
+
34
+ def get_endpoint_exists(self, url: str) -> bool:
35
+ """
36
+ Check whether [get] endpoint exists under given URL.
37
+ """
38
+ try:
39
+ result = self.send_rest_request(url, method="get")
40
+ return bool(result) or result == []
41
+ except InterfaceError as err:
42
+ if "404 Not Found" in str(err):
43
+ return False
44
+ raise err
45
+
46
+ def send_rest_request(
47
+ self, url: str, method: str, data: Optional[Dict[str, Any]] = None
48
+ ):
49
+ """
50
+ Executes rest request via snowflake.connector.network.SnowflakeRestful
51
+ """
52
+ # SnowflakeRestful.request assumes that API response is always a dict,
53
+ # which is not true in case of this API, so we need to do this workaround:
54
+ from snowflake.connector.network import (
55
+ CONTENT_TYPE_APPLICATION_JSON,
56
+ HTTP_HEADER_ACCEPT,
57
+ HTTP_HEADER_CONTENT_TYPE,
58
+ HTTP_HEADER_USER_AGENT,
59
+ PYTHON_CONNECTOR_USER_AGENT,
60
+ )
61
+
62
+ log.debug("Sending %s request to %s", method, url)
63
+ full_url = f"{self.rest.server_url}{url}"
64
+ headers = {
65
+ HTTP_HEADER_CONTENT_TYPE: CONTENT_TYPE_APPLICATION_JSON,
66
+ HTTP_HEADER_ACCEPT: CONTENT_TYPE_APPLICATION_JSON,
67
+ HTTP_HEADER_USER_AGENT: PYTHON_CONNECTOR_USER_AGENT,
68
+ }
69
+ return self.rest.fetch(
70
+ method=method,
71
+ full_url=full_url,
72
+ headers=headers,
73
+ token=self.rest.token,
74
+ data=json.dumps(data if data else {}),
75
+ no_retry=True,
76
+ )
77
+
78
+ def determine_url_for_create_query(
79
+ self, *, plural_object_type: str
80
+ ) -> Optional[str]:
81
+ """
82
+ Determine an url for creating an object of given type via REST API.
83
+ The function returns None if URL cannot be determined.
84
+
85
+ URLs we check:
86
+ * /api/v2/<type>/
87
+ * /api/v2/databases/<database>/<type>/
88
+ * /api/v2/databases/<database>/schemas/<schema>/<type>/
89
+
90
+ We assume that the URLs for CREATE and LIST are the same for every type of object
91
+ (endpoints differ by method: POST vs GET, accordingly).
92
+ To check whether an URL exists, we send read-only GET request (LIST endpoint,
93
+ which should imply CREATE endpoint).
94
+ """
95
+ urls_to_be_checked: List[Optional[str]] = [
96
+ f"{SF_REST_API_URL_PREFIX}/{plural_object_type}/",
97
+ (
98
+ f"{SF_REST_API_URL_PREFIX}/databases/{self.conn.database}/{plural_object_type}/"
99
+ if self.conn.database
100
+ else None
101
+ ),
102
+ (
103
+ f"{SF_REST_API_URL_PREFIX}/databases/{self.conn.database}/schemas/{self.conn.schema}/{plural_object_type}/"
104
+ if self.conn.database and self.conn.schema
105
+ else None
106
+ ),
107
+ ]
108
+
109
+ for url in urls_to_be_checked:
110
+ if url and self.get_endpoint_exists(url):
111
+ return url
112
+
113
+ return None
@@ -0,0 +1,43 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import re
18
+
19
+ # 7-bit C1 ANSI sequences
20
+ _ANSI_ESCAPE = re.compile(
21
+ r"""
22
+ \x1B # ESC
23
+ (?: # 7-bit C1 Fe (except CSI)
24
+ [@-Z\\-_]
25
+ | # or [ for CSI, followed by a control sequence
26
+ \[
27
+ [0-?]* # Parameter bytes
28
+ [ -/]* # Intermediate bytes
29
+ [@-~] # Final byte
30
+ )
31
+ """,
32
+ re.VERBOSE,
33
+ )
34
+
35
+
36
+ def sanitize_for_terminal(text: str) -> str | None:
37
+ """
38
+ Escape ASCII escape codes in string. This should be always used
39
+ when printing output to terminal.
40
+ """
41
+ if text is None:
42
+ return None
43
+ return _ANSI_ESCAPE.sub("", text)
@@ -22,6 +22,7 @@ from textwrap import dedent
22
22
  from typing import Iterable, Optional, Tuple
23
23
 
24
24
  from snowflake.cli.api.cli_global_context import cli_context
25
+ from snowflake.cli.api.console import cli_console
25
26
  from snowflake.cli.api.constants import ObjectType
26
27
  from snowflake.cli.api.exceptions import (
27
28
  DatabaseNotProvidedError,
@@ -212,3 +213,9 @@ class SqlExecutionMixin:
212
213
  lambda row: row[name_col] == unquote_identifier(unqualified_name),
213
214
  )
214
215
  return show_obj_row
216
+
217
+
218
+ class VerboseCursor(SnowflakeCursor):
219
+ def execute(self, command: str, *args, **kwargs):
220
+ cli_console.message(command)
221
+ super().execute(command, *args, **kwargs)
@@ -15,15 +15,19 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  import copy
18
- import os
19
18
  from typing import Any, Optional
20
19
 
21
- from jinja2 import Environment, nodes
20
+ from jinja2 import Environment, TemplateSyntaxError, nodes
22
21
  from packaging.version import Version
22
+ from snowflake.cli.api.console import cli_console as cc
23
23
  from snowflake.cli.api.exceptions import CycleDetectedError, InvalidTemplate
24
+ from snowflake.cli.api.project.schemas.project_definition import (
25
+ ProjectDefinition,
26
+ ProjectProperties,
27
+ )
24
28
  from snowflake.cli.api.utils.dict_utils import traverse
25
29
  from snowflake.cli.api.utils.graph import Graph, Node
26
- from snowflake.cli.api.utils.models import EnvironWithDefinedDictFallback
30
+ from snowflake.cli.api.utils.models import ProjectEnvironment
27
31
  from snowflake.cli.api.utils.rendering import CONTEXT_KEY, get_snowflake_cli_jinja_env
28
32
  from snowflake.cli.api.utils.types import Context, Definition
29
33
 
@@ -46,7 +50,13 @@ class TemplatedEnvironment:
46
50
 
47
51
  def get_referenced_vars(self, template_value: Any) -> set[TemplateVar]:
48
52
  template_str = str(template_value)
49
- ast = self._jinja_env.parse(template_str)
53
+ try:
54
+ ast = self._jinja_env.parse(template_str)
55
+ except TemplateSyntaxError as e:
56
+ raise InvalidTemplate(
57
+ f"Error parsing template from project definition file. Value: '{template_str}'. Error: {e}"
58
+ ) from e
59
+
50
60
  return self._get_referenced_vars(ast, template_str)
51
61
 
52
62
  def _get_referenced_vars(
@@ -159,9 +169,13 @@ class TemplateVar:
159
169
  current_dict_level = current_dict_level[key]
160
170
 
161
171
  value = current_dict_level
162
- if value is None or isinstance(value, (dict, list)):
172
+
173
+ if value is None:
174
+ raise InvalidTemplate(f"Template variable {self.key} does not have a value")
175
+
176
+ if isinstance(value, (dict, list)):
163
177
  raise InvalidTemplate(
164
- f"Template variable {self.key} does not contain a valid value"
178
+ f"Template variable {self.key} does not have a scalar value"
165
179
  )
166
180
 
167
181
  return value
@@ -174,17 +188,20 @@ class TemplateVar:
174
188
 
175
189
 
176
190
  def _build_dependency_graph(
177
- env: TemplatedEnvironment, all_vars: set[TemplateVar], context: Context
191
+ env: TemplatedEnvironment,
192
+ all_vars: set[TemplateVar],
193
+ context: Context,
194
+ environment_overrides: ProjectEnvironment,
178
195
  ) -> Graph[TemplateVar]:
179
196
  dependencies_graph = Graph[TemplateVar]()
180
197
  for variable in all_vars:
181
198
  dependencies_graph.add(Node[TemplateVar](key=variable.key, data=variable))
182
199
 
183
200
  for variable in all_vars:
184
- if variable.is_env_var and variable.get_env_var_name() in os.environ:
185
- # If variable is found in os.environ, then use the value as is
186
- # skip rendering by pre-setting the rendered_value attribute
187
- env_value = os.environ.get(variable.get_env_var_name())
201
+ # If variable is found in os.environ or from cli override, then use the value as is
202
+ # skip rendering by pre-setting the rendered_value attribute
203
+ if variable.is_env_var and variable.get_env_var_name() in environment_overrides:
204
+ env_value = environment_overrides.get(variable.get_env_var_name())
188
205
  variable.rendered_value = env_value
189
206
  variable.templated_value = env_value
190
207
  else:
@@ -210,7 +227,41 @@ def _render_graph_node(env: TemplatedEnvironment, node: Node[TemplateVar]) -> No
210
227
  node.data.rendered_value = env.render(node.data.templated_value, current_context)
211
228
 
212
229
 
213
- def render_definition_template(original_definition: Definition) -> Definition:
230
+ def _validate_env_section(env_section: dict):
231
+ if not isinstance(env_section, dict):
232
+ raise InvalidTemplate(
233
+ "env section in project definition file should be a mapping"
234
+ )
235
+ for variable, value in env_section.items():
236
+ if value is None or isinstance(value, (dict, list)):
237
+ raise InvalidTemplate(
238
+ f"Variable {variable} in env section of project definition file should be a scalar"
239
+ )
240
+
241
+
242
+ def _get_referenced_vars_in_definition(
243
+ template_env: TemplatedEnvironment, definition: Definition
244
+ ):
245
+ referenced_vars = set()
246
+
247
+ def find_any_template_vars(element):
248
+ referenced_vars.update(template_env.get_referenced_vars(element))
249
+
250
+ traverse(definition, visit_action=find_any_template_vars)
251
+
252
+ return referenced_vars
253
+
254
+
255
+ def _template_version_warning():
256
+ cc.warning(
257
+ "Ignoring template pattern in project definition file. "
258
+ "Update 'definition_version' to 1.1 or later in snowflake.yml to enable template expansion."
259
+ )
260
+
261
+
262
+ def render_definition_template(
263
+ original_definition: Optional[Definition], context_overrides: Context
264
+ ) -> ProjectProperties:
214
265
  """
215
266
  Takes a definition file as input. An arbitrary structure containing dict|list|scalars,
216
267
  with the top level being a dictionary.
@@ -225,23 +276,42 @@ def render_definition_template(original_definition: Definition) -> Definition:
225
276
  # protect input from update
226
277
  definition = copy.deepcopy(original_definition)
227
278
 
279
+ # start with an environment from overrides and environment variables:
280
+ override_env = context_overrides.get(CONTEXT_KEY, {}).get("env", {})
281
+ environment_overrides = ProjectEnvironment(
282
+ default_env={}, override_env=override_env
283
+ )
284
+
285
+ if definition is None:
286
+ return ProjectProperties(None, {CONTEXT_KEY: {"env": environment_overrides}})
287
+
288
+ project_context = {CONTEXT_KEY: definition}
289
+ template_env = TemplatedEnvironment(get_snowflake_cli_jinja_env())
290
+
228
291
  if "definition_version" not in definition or Version(
229
292
  definition["definition_version"]
230
293
  ) < Version("1.1"):
231
- return definition
294
+ try:
295
+ referenced_vars = _get_referenced_vars_in_definition(
296
+ template_env, definition
297
+ )
298
+ if referenced_vars:
299
+ _template_version_warning()
300
+ except Exception:
301
+ # also warn on Exception, as it means the user is incorrectly attempting to use templating
302
+ _template_version_warning()
232
303
 
233
- template_env = TemplatedEnvironment(get_snowflake_cli_jinja_env())
234
- project_context = {CONTEXT_KEY: definition}
304
+ project_definition = ProjectDefinition(**original_definition)
305
+ project_context[CONTEXT_KEY]["env"] = environment_overrides
306
+ return ProjectProperties(project_definition, project_context)
235
307
 
236
- referenced_vars = set()
308
+ default_env = definition.get("env", {})
309
+ _validate_env_section(default_env)
237
310
 
238
- def find_any_template_vars(element):
239
- referenced_vars.update(template_env.get_referenced_vars(element))
240
-
241
- traverse(definition, visit_action=find_any_template_vars)
311
+ referenced_vars = _get_referenced_vars_in_definition(template_env, definition)
242
312
 
243
313
  dependencies_graph = _build_dependency_graph(
244
- template_env, referenced_vars, project_context
314
+ template_env, referenced_vars, project_context, environment_overrides
245
315
  )
246
316
 
247
317
  def on_cycle_action(node: Node[TemplateVar]):
@@ -265,7 +335,7 @@ def render_definition_template(original_definition: Definition) -> Definition:
265
335
  update_action=lambda val: template_env.render(val, final_context),
266
336
  )
267
337
 
268
- current_env = definition.setdefault("env", {})
269
- definition["env"] = EnvironWithDefinedDictFallback(current_env)
270
-
271
- return definition
338
+ definition["env"] = ProjectEnvironment(default_env, override_env)
339
+ project_context[CONTEXT_KEY] = definition
340
+ project_definition = ProjectDefinition(**definition)
341
+ return ProjectProperties(project_definition, project_context)
@@ -15,41 +15,46 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  import os
18
- from typing import Any, Dict
18
+ from typing import Any, Dict, Optional
19
19
 
20
- from snowflake.cli.api.exceptions import InvalidTemplate
20
+ from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel
21
21
 
22
22
 
23
- def _validate_env(current_env: dict):
24
- if not isinstance(current_env, dict):
25
- raise InvalidTemplate(
26
- "env section in project definition file should be a mapping"
27
- )
28
- for variable, value in current_env.items():
29
- if value is None or isinstance(value, (dict, list)):
30
- raise InvalidTemplate(
31
- f"Variable {variable} in env section or project definition file should be a scalar"
32
- )
23
+ class ProjectEnvironment(UpdatableModel):
24
+ """
25
+ This class handles retrieval of project env variables.
26
+ These env variables can be accessed through templating, as ctx.env.<var_name>
33
27
 
28
+ This class checks for env values in the following order:
29
+ - Check for overrides values from the command line. Use these values first.
30
+ - Check if these variables are available as environment variables and return them if found.
31
+ - Check for default values from the project definition file.
32
+ """
34
33
 
35
- class EnvironWithDefinedDictFallback(Dict):
36
- def __init__(self, dict_input: dict):
37
- _validate_env(dict_input)
38
- super().__init__(dict_input)
34
+ override_env: Dict[str, Any] = {}
35
+ default_env: Dict[str, Any] = {}
39
36
 
40
- def __getattr__(self, item):
41
- try:
42
- return self[item]
43
- except KeyError as e:
44
- raise AttributeError(e)
37
+ def __init__(
38
+ self, default_env: Dict[str, Any], override_env: Optional[Dict[str, Any]] = None
39
+ ):
40
+ super().__init__(self, default_env=default_env, override_env=override_env or {})
45
41
 
46
42
  def __getitem__(self, item):
43
+ if item in self.override_env:
44
+ return self.override_env.get(item)
47
45
  if item in os.environ:
48
46
  return os.environ[item]
49
- return super().__getitem__(item)
47
+ return self.default_env[item]
50
48
 
51
- def __contains__(self, item):
52
- return item in os.environ or super().__contains__(item)
49
+ def get(self, item, default=None):
50
+ try:
51
+ return self[item]
52
+ except KeyError:
53
+ return default
53
54
 
54
- def update_from_dict(self, update_values: Dict[str, Any]):
55
- return super().update(update_values)
55
+ def __contains__(self, item) -> bool:
56
+ try:
57
+ self[item]
58
+ return True
59
+ except KeyError:
60
+ return False