snowflake-cli-labs 3.0.0rc1__py3-none-any.whl → 3.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. snowflake/cli/__about__.py +1 -1
  2. snowflake/cli/_app/cli_app.py +10 -1
  3. snowflake/cli/_app/snow_connector.py +76 -29
  4. snowflake/cli/_app/telemetry.py +8 -4
  5. snowflake/cli/_app/version_check.py +74 -0
  6. snowflake/cli/_plugins/git/commands.py +55 -14
  7. snowflake/cli/_plugins/nativeapp/codegen/snowpark/python_processor.py +2 -5
  8. snowflake/cli/_plugins/nativeapp/codegen/templates/templates_processor.py +49 -31
  9. snowflake/cli/_plugins/nativeapp/manager.py +46 -87
  10. snowflake/cli/_plugins/nativeapp/run_processor.py +56 -260
  11. snowflake/cli/_plugins/nativeapp/same_account_install_method.py +74 -0
  12. snowflake/cli/_plugins/nativeapp/teardown_processor.py +9 -152
  13. snowflake/cli/_plugins/nativeapp/v2_conversions/v2_to_v1_decorator.py +91 -17
  14. snowflake/cli/_plugins/snowpark/commands.py +1 -1
  15. snowflake/cli/_plugins/snowpark/models.py +2 -1
  16. snowflake/cli/_plugins/streamlit/commands.py +1 -1
  17. snowflake/cli/_plugins/streamlit/manager.py +9 -0
  18. snowflake/cli/_plugins/workspace/action_context.py +2 -1
  19. snowflake/cli/_plugins/workspace/commands.py +48 -16
  20. snowflake/cli/_plugins/workspace/manager.py +1 -0
  21. snowflake/cli/api/cli_global_context.py +136 -313
  22. snowflake/cli/api/commands/flags.py +76 -91
  23. snowflake/cli/api/commands/snow_typer.py +6 -4
  24. snowflake/cli/api/config.py +1 -1
  25. snowflake/cli/api/connections.py +214 -0
  26. snowflake/cli/api/console/abc.py +4 -2
  27. snowflake/cli/api/entities/application_entity.py +687 -2
  28. snowflake/cli/api/entities/application_package_entity.py +151 -46
  29. snowflake/cli/api/entities/common.py +1 -0
  30. snowflake/cli/api/entities/utils.py +41 -17
  31. snowflake/cli/api/identifiers.py +3 -0
  32. snowflake/cli/api/project/definition.py +11 -0
  33. snowflake/cli/api/project/definition_conversion.py +171 -13
  34. snowflake/cli/api/project/schemas/entities/common.py +0 -12
  35. snowflake/cli/api/project/schemas/identifier_model.py +2 -2
  36. snowflake/cli/api/project/schemas/project_definition.py +101 -39
  37. snowflake/cli/api/rendering/project_definition_templates.py +4 -0
  38. snowflake/cli/api/rendering/sql_templates.py +7 -0
  39. snowflake/cli/api/utils/definition_rendering.py +3 -1
  40. {snowflake_cli_labs-3.0.0rc1.dist-info → snowflake_cli_labs-3.0.0rc2.dist-info}/METADATA +6 -6
  41. {snowflake_cli_labs-3.0.0rc1.dist-info → snowflake_cli_labs-3.0.0rc2.dist-info}/RECORD +44 -42
  42. snowflake/cli/api/commands/typer_pre_execute.py +0 -26
  43. {snowflake_cli_labs-3.0.0rc1.dist-info → snowflake_cli_labs-3.0.0rc2.dist-info}/WHEEL +0 -0
  44. {snowflake_cli_labs-3.0.0rc1.dist-info → snowflake_cli_labs-3.0.0rc2.dist-info}/entry_points.txt +0 -0
  45. {snowflake_cli_labs-3.0.0rc1.dist-info → snowflake_cli_labs-3.0.0rc2.dist-info}/licenses/LICENSE +0 -0
@@ -14,4 +14,4 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- VERSION = "3.0.0rc1"
17
+ VERSION = "3.0.0rc2"
@@ -38,6 +38,10 @@ from snowflake.cli._app.dev.pycharm_remote_debug import (
38
38
  )
39
39
  from snowflake.cli._app.main_typer import SnowCliMainTyper
40
40
  from snowflake.cli._app.printing import MessageResult, print_result
41
+ from snowflake.cli._app.version_check import (
42
+ get_new_version_msg,
43
+ show_new_version_banner_callback,
44
+ )
41
45
  from snowflake.cli.api import Api, api_provider
42
46
  from snowflake.cli.api.config import config_init
43
47
  from snowflake.cli.api.output.formats import OutputFormat
@@ -145,8 +149,13 @@ def _info_callback(value: bool):
145
149
 
146
150
  def app_factory() -> SnowCliMainTyper:
147
151
  app = SnowCliMainTyper()
152
+ new_version_msg = get_new_version_msg()
148
153
 
149
- @app.callback(invoke_without_command=True)
154
+ @app.callback(
155
+ invoke_without_command=True,
156
+ epilog=new_version_msg,
157
+ result_callback=show_new_version_banner_callback(new_version_msg),
158
+ )
150
159
  def default(
151
160
  ctx: typer.Context,
152
161
  version: bool = typer.Option(
@@ -25,11 +25,9 @@ from snowflake.cli._app.constants import (
25
25
  PARAM_APPLICATION_NAME,
26
26
  )
27
27
  from snowflake.cli._app.telemetry import command_info
28
- from snowflake.cli.api.cli_global_context import get_cli_context
29
28
  from snowflake.cli.api.config import (
30
29
  get_connection_dict,
31
- get_default_connection_dict,
32
- get_default_connection_name,
30
+ get_env_value,
33
31
  )
34
32
  from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB
35
33
  from snowflake.cli.api.exceptions import (
@@ -46,6 +44,33 @@ log = logging.getLogger(__name__)
46
44
  ENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN ENCRYPTED PRIVATE KEY-----"
47
45
  UNENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN PRIVATE KEY-----"
48
46
 
47
+ # connection keys that can be set using SNOWFLAKE_* env vars
48
+ SUPPORTED_ENV_OVERRIDES = [
49
+ "account",
50
+ "user",
51
+ "password",
52
+ "authenticator",
53
+ "private_key_file",
54
+ "private_key_path",
55
+ "database",
56
+ "schema",
57
+ "role",
58
+ "warehouse",
59
+ "session_token",
60
+ "master_token",
61
+ "token_file_path",
62
+ ]
63
+
64
+ # mapping of found key -> key to set
65
+ CONNECTION_KEY_ALIASES = {"private_key_path": "private_key_file"}
66
+
67
+
68
+ def _resolve_alias(key_or_alias: str):
69
+ """
70
+ Given the key of an override / env var, what key should it be set as in the connection parameters?
71
+ """
72
+ return CONNECTION_KEY_ALIASES.get(key_or_alias, key_or_alias)
73
+
49
74
 
50
75
  def connect_to_snowflake(
51
76
  temporary_connection: bool = False,
@@ -58,6 +83,10 @@ def connect_to_snowflake(
58
83
  ) -> SnowflakeConnection:
59
84
  if temporary_connection and connection_name:
60
85
  raise ClickException("Can't use connection name and temporary connection.")
86
+ elif not temporary_connection and not connection_name:
87
+ raise ClickException(
88
+ "One of connection name or temporary connection is required."
89
+ )
61
90
 
62
91
  using_session_token = (
63
92
  "session_token" in overrides and overrides["session_token"] is not None
@@ -70,37 +99,33 @@ def connect_to_snowflake(
70
99
  )
71
100
 
72
101
  if connection_name:
73
- connection_parameters = get_connection_dict(connection_name)
74
- connection_parameters = get_connection_dict(connection_name)
102
+ connection_parameters = {
103
+ _resolve_alias(k): v
104
+ for k, v in get_connection_dict(connection_name).items()
105
+ }
75
106
  elif temporary_connection:
76
107
  connection_parameters = {} # we will apply overrides in next step
77
- else:
78
- connection_parameters = get_default_connection_dict()
79
- get_cli_context().connection_context.set_connection_name(
80
- get_default_connection_name()
81
- )
82
108
 
83
109
  # Apply overrides to connection details
110
+ # (1) Command line override case
84
111
  for key, value in overrides.items():
85
- # Command line override case
86
- if value:
87
- connection_parameters[key] = value
88
- continue
112
+ if value is not None:
113
+ connection_parameters[_resolve_alias(key)] = value
89
114
 
90
- # Generic environment variable case, apply only if value not passed via flag or connection variable
91
- generic_env_value = os.environ.get(f"SNOWFLAKE_{key}".upper())
92
- if key not in connection_parameters and generic_env_value:
93
- connection_parameters[key] = generic_env_value
94
- continue
115
+ # (2) Generic environment variable case
116
+ # ... apply only if value not passed via flag or connection variable
117
+ for key in SUPPORTED_ENV_OVERRIDES:
118
+ generic_env_value = get_env_value(key=key)
119
+ connection_key = _resolve_alias(key)
120
+ if connection_key not in connection_parameters and generic_env_value:
121
+ connection_parameters[connection_key] = generic_env_value
95
122
 
96
123
  # Clean up connection params
97
124
  connection_parameters = {
98
125
  k: v for k, v in connection_parameters.items() if v is not None
99
126
  }
100
127
 
101
- connection_parameters = update_connection_details_with_private_key(
102
- connection_parameters
103
- )
128
+ update_connection_details_with_private_key(connection_parameters)
104
129
 
105
130
  if mfa_passcode:
106
131
  connection_parameters["passcode"] = mfa_passcode
@@ -169,12 +194,32 @@ def update_connection_details_with_private_key(connection_parameters: Dict):
169
194
  _load_private_key(connection_parameters, "private_key_file")
170
195
  elif "private_key_path" in connection_parameters:
171
196
  _load_private_key(connection_parameters, "private_key_path")
197
+ elif "private_key_raw" in connection_parameters:
198
+ _load_private_key_from_parameters(connection_parameters, "private_key_raw")
172
199
  return connection_parameters
173
200
 
174
201
 
175
202
  def _load_private_key(connection_parameters: Dict, private_key_var_name: str) -> None:
176
203
  if connection_parameters.get("authenticator") == "SNOWFLAKE_JWT":
177
- private_key = _load_pem_to_der(connection_parameters[private_key_var_name])
204
+ private_key_pem = _load_pem_from_file(
205
+ connection_parameters[private_key_var_name]
206
+ )
207
+ private_key = _load_pem_to_der(private_key_pem)
208
+ connection_parameters["private_key"] = private_key
209
+ del connection_parameters[private_key_var_name]
210
+ else:
211
+ raise ClickException(
212
+ "Private Key authentication requires authenticator set to SNOWFLAKE_JWT"
213
+ )
214
+
215
+
216
+ def _load_private_key_from_parameters(
217
+ connection_parameters: Dict, private_key_var_name: str
218
+ ) -> None:
219
+ if connection_parameters.get("authenticator") == "SNOWFLAKE_JWT":
220
+ private_key_pem = connection_parameters[private_key_var_name]
221
+ private_key_pem = private_key_pem.encode("utf-8")
222
+ private_key = _load_pem_to_der(private_key_pem)
178
223
  connection_parameters["private_key"] = private_key
179
224
  del connection_parameters[private_key_var_name]
180
225
  else:
@@ -191,17 +236,19 @@ def _update_connection_application_name(connection_parameters: Dict):
191
236
  connection_parameters.update(connection_application_params)
192
237
 
193
238
 
194
- def _load_pem_to_der(private_key_file: str) -> bytes:
195
- """
196
- Given a private key file path (in PEM format), decode key data into DER
197
- format
198
- """
199
-
239
+ def _load_pem_from_file(private_key_file: str) -> bytes:
200
240
  with SecurePath(private_key_file).open(
201
241
  "rb", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB
202
242
  ) as f:
203
243
  private_key_pem = f.read()
244
+ return private_key_pem
204
245
 
246
+
247
+ def _load_pem_to_der(private_key_pem: bytes) -> bytes:
248
+ """
249
+ Given a private key file path (in PEM format), decode key data into DER
250
+ format
251
+ """
205
252
  private_key_passphrase = os.getenv("PRIVATE_KEY_PASSPHRASE", None)
206
253
  if (
207
254
  private_key_pem.startswith(ENCRYPTED_PKCS8_PK_HEADER)
@@ -22,7 +22,10 @@ from typing import Any, Dict, Union
22
22
  import click
23
23
  from snowflake.cli.__about__ import VERSION
24
24
  from snowflake.cli._app.constants import PARAM_APPLICATION_NAME
25
- from snowflake.cli.api.cli_global_context import get_cli_context
25
+ from snowflake.cli.api.cli_global_context import (
26
+ _CliGlobalContextAccess,
27
+ get_cli_context,
28
+ )
26
29
  from snowflake.cli.api.commands.execution_metadata import ExecutionMetadata
27
30
  from snowflake.cli.api.config import get_feature_flags_section
28
31
  from snowflake.cli.api.output.formats import OutputFormat
@@ -106,8 +109,9 @@ def python_version() -> str:
106
109
 
107
110
 
108
111
  class CLITelemetryClient:
109
- def __init__(self, ctx):
110
- self._ctx = ctx
112
+ @property
113
+ def _ctx(self) -> _CliGlobalContextAccess:
114
+ return get_cli_context()
111
115
 
112
116
  @staticmethod
113
117
  def generate_telemetry_data_dict(
@@ -143,7 +147,7 @@ class CLITelemetryClient:
143
147
  self._telemetry.send_batch()
144
148
 
145
149
 
146
- _telemetry = CLITelemetryClient(ctx=get_cli_context())
150
+ _telemetry = CLITelemetryClient()
147
151
 
148
152
 
149
153
  @ignore_exceptions()
@@ -0,0 +1,74 @@
1
+ import json
2
+ import time
3
+
4
+ import requests
5
+ from packaging.version import Version
6
+ from snowflake.cli.__about__ import VERSION
7
+ from snowflake.cli.api.console import cli_console
8
+ from snowflake.cli.api.secure_path import SecurePath
9
+ from snowflake.connector.config_manager import CONFIG_MANAGER
10
+
11
+
12
+ def get_new_version_msg() -> str | None:
13
+ last = _VersionCache().get_last_version()
14
+ current = Version(VERSION)
15
+ if last and last > current:
16
+ return f"\nNew version of Snowflake CLI available. Newest: {last}, current: {VERSION}\n"
17
+ return None
18
+
19
+
20
+ def show_new_version_banner_callback(msg):
21
+ def _callback(*args, **kwargs):
22
+ if msg:
23
+ cli_console.message(msg)
24
+
25
+ return _callback
26
+
27
+
28
+ class _VersionCache:
29
+ _last_time = "last_time_check"
30
+ _version = "version"
31
+ _version_cache_file = SecurePath(
32
+ CONFIG_MANAGER.file_path.parent / ".cli_version.cache"
33
+ )
34
+
35
+ def __init__(self):
36
+ self._cache_file = _VersionCache._version_cache_file
37
+
38
+ def _save_latest_version(self, version: str):
39
+ data = {
40
+ _VersionCache._last_time: time.time(),
41
+ _VersionCache._version: str(version),
42
+ }
43
+ self._cache_file.write_text(json.dumps(data))
44
+
45
+ @staticmethod
46
+ def _get_version_from_pypi() -> str | None:
47
+ headers = {"Content-Type": "application/vnd.pypi.simple.v1+json"}
48
+ response = requests.get(
49
+ "https://pypi.org/pypi/snowflake-cli-labs/json", headers=headers, timeout=3
50
+ )
51
+ response.raise_for_status()
52
+ return response.json()["info"]["version"]
53
+
54
+ def _update_latest_version(self) -> Version | None:
55
+ version = self._get_version_from_pypi()
56
+ if version is None:
57
+ return None
58
+ self._save_latest_version(version)
59
+ return Version(version)
60
+
61
+ def _read_latest_version(self) -> Version | None:
62
+ if self._cache_file.exists():
63
+ data = json.loads(self._cache_file.read_text())
64
+ now = time.time()
65
+ if data[_VersionCache._last_time] > now - 60 * 60:
66
+ return Version(data[_VersionCache._version])
67
+
68
+ return self._update_latest_version()
69
+
70
+ def get_last_version(self) -> Version | None:
71
+ try:
72
+ return self._read_latest_version()
73
+ except: # anything, this it not crucial feature
74
+ return None
@@ -18,7 +18,7 @@ import itertools
18
18
  import logging
19
19
  from os import path
20
20
  from pathlib import Path
21
- from typing import List, Optional
21
+ from typing import Dict, List, Optional
22
22
 
23
23
  import typer
24
24
  from click import ClickException
@@ -41,6 +41,7 @@ from snowflake.cli.api.console.console import cli_console
41
41
  from snowflake.cli.api.constants import ObjectType
42
42
  from snowflake.cli.api.output.types import CollectionResult, CommandResult, QueryResult
43
43
  from snowflake.cli.api.utils.path_utils import is_stage_path
44
+ from snowflake.connector import DictCursor
44
45
 
45
46
  app = SnowTyperFactory(
46
47
  name="git",
@@ -98,6 +99,24 @@ def _validate_origin_url(url: str) -> None:
98
99
  raise ClickException("Url address should start with 'https'")
99
100
 
100
101
 
102
+ def _unique_new_object_name(
103
+ om: ObjectManager, object_type: ObjectType, proposed_fqn: FQN
104
+ ) -> str:
105
+ existing_objects: List[Dict] = om.show(
106
+ object_type=object_type.value.cli_name,
107
+ like=f"{proposed_fqn.name}%",
108
+ cursor_class=DictCursor,
109
+ ).fetchall()
110
+ existing_names = set(o["name"].upper() for o in existing_objects)
111
+
112
+ result = proposed_fqn.name
113
+ i = 1
114
+ while result.upper() in existing_names:
115
+ result = proposed_fqn.name + str(i)
116
+ i += 1
117
+ return result
118
+
119
+
101
120
  @app.command("setup", requires_connection=True)
102
121
  def setup(
103
122
  repository_name: FQN = RepoNameArgument,
@@ -128,13 +147,29 @@ def setup(
128
147
  should_create_secret = False
129
148
  secret_name = None
130
149
  if secret_needed:
131
- secret_name = f"{repository_name}_secret"
132
- secret_name = typer.prompt(
133
- "Secret identifier (will be created if not exists)", default=secret_name
150
+ default_secret_name = (
151
+ FQN.from_string(f"{repository_name.name}_secret")
152
+ .set_schema(repository_name.schema)
153
+ .set_database(repository_name.database)
154
+ )
155
+ default_secret_name.set_name(
156
+ _unique_new_object_name(
157
+ om, object_type=ObjectType.SECRET, proposed_fqn=default_secret_name
158
+ ),
134
159
  )
135
- secret_fqn = FQN.from_string(secret_name)
160
+ secret_name = FQN.from_string(
161
+ typer.prompt(
162
+ "Secret identifier (will be created if not exists)",
163
+ default=default_secret_name.name,
164
+ )
165
+ )
166
+ if not secret_name.database:
167
+ secret_name.set_database(repository_name.database)
168
+ if not secret_name.schema:
169
+ secret_name.set_schema(repository_name.schema)
170
+
136
171
  if om.object_exists(
137
- object_type=ObjectType.SECRET.value.cli_name, fqn=secret_fqn
172
+ object_type=ObjectType.SECRET.value.cli_name, fqn=secret_name
138
173
  ):
139
174
  cli_console.step(f"Using existing secret '{secret_name}'")
140
175
  else:
@@ -143,24 +178,30 @@ def setup(
143
178
  secret_username = typer.prompt("username")
144
179
  secret_password = typer.prompt("password/token", hide_input=True)
145
180
 
146
- api_integration = f"{repository_name}_api_integration"
147
- api_integration = typer.prompt(
148
- "API integration identifier (will be created if not exists)",
149
- default=api_integration,
181
+ # API integration is an account-level object
182
+ api_integration = FQN.from_string(f"{repository_name.name}_api_integration")
183
+ api_integration.set_name(
184
+ typer.prompt(
185
+ "API integration identifier (will be created if not exists)",
186
+ default=_unique_new_object_name(
187
+ om,
188
+ object_type=ObjectType.INTEGRATION,
189
+ proposed_fqn=api_integration,
190
+ ),
191
+ )
150
192
  )
151
- api_integration_fqn = FQN.from_string(api_integration)
152
193
 
153
194
  if should_create_secret:
154
195
  manager.create_password_secret(
155
- name=secret_fqn, username=secret_username, password=secret_password
196
+ name=secret_name, username=secret_username, password=secret_password
156
197
  )
157
198
  cli_console.step(f"Secret '{secret_name}' successfully created.")
158
199
 
159
200
  if not om.object_exists(
160
- object_type=ObjectType.INTEGRATION.value.cli_name, fqn=api_integration_fqn
201
+ object_type=ObjectType.INTEGRATION.value.cli_name, fqn=api_integration
161
202
  ):
162
203
  manager.create_api_integration(
163
- name=api_integration_fqn,
204
+ name=api_integration,
164
205
  api_provider="git_https_api",
165
206
  allowed_prefix=url,
166
207
  secret=secret_name,
@@ -323,11 +323,8 @@ class SnowparkAnnotationProcessor(ArtifactProcessor):
323
323
  predicate=is_python_file_artifact,
324
324
  )
325
325
  ):
326
- cc.step(
327
- "Processing Snowpark annotations from {}".format(
328
- dest_file.relative_to(bundle_map.deploy_root())
329
- )
330
- )
326
+ src_file_name = src_file.relative_to(self._bundle_ctx.project_root)
327
+ cc.step(f"Processing Snowpark annotations from {src_file_name}")
331
328
  collected_extension_function_json = _execute_in_sandbox(
332
329
  py_file=str(dest_file.resolve()),
333
330
  deploy_root=self._bundle_ctx.deploy_root,
@@ -14,6 +14,7 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ from pathlib import Path
17
18
  from typing import Optional
18
19
 
19
20
  import jinja2
@@ -30,27 +31,72 @@ from snowflake.cli.api.project.schemas.native_app.path_mapping import (
30
31
  )
31
32
  from snowflake.cli.api.rendering.project_definition_templates import (
32
33
  get_client_side_jinja_env,
34
+ has_client_side_templates,
33
35
  )
34
36
  from snowflake.cli.api.rendering.sql_templates import (
35
37
  choose_sql_jinja_env_based_on_template_syntax,
38
+ has_sql_templates,
36
39
  )
37
40
 
38
41
 
42
+ def _is_sql_file(file: Path) -> bool:
43
+ return file.name.lower().endswith(".sql")
44
+
45
+
39
46
  class TemplatesProcessor(ArtifactProcessor):
40
47
  """
41
48
  Processor class to perform template expansion on all relevant artifacts (specified in the project definition file).
42
49
  """
43
50
 
51
+ def expand_templates_in_file(self, src: Path, dest: Path) -> None:
52
+ """
53
+ Expand templates in the file.
54
+ """
55
+ if src.is_dir():
56
+ return
57
+
58
+ with self.edit_file(dest) as file:
59
+ if not has_client_side_templates(file.contents) and not (
60
+ _is_sql_file(dest) and has_sql_templates(file.contents)
61
+ ):
62
+ return
63
+
64
+ src_file_name = src.relative_to(self._bundle_ctx.project_root)
65
+ cc.step(f"Expanding templates in {src_file_name}")
66
+ with cc.indented():
67
+ try:
68
+ jinja_env = (
69
+ choose_sql_jinja_env_based_on_template_syntax(
70
+ file.contents, reference_name=src_file_name
71
+ )
72
+ if _is_sql_file(dest)
73
+ else get_client_side_jinja_env()
74
+ )
75
+ expanded_template = jinja_env.from_string(file.contents).render(
76
+ get_cli_context().template_context
77
+ )
78
+
79
+ # For now, we are printing the source file path in the error message
80
+ # instead of the destination file path to make it easier for the user
81
+ # to identify the file that has the error, and edit the correct file.
82
+ except jinja2.TemplateSyntaxError as e:
83
+ raise InvalidTemplateInFileError(src_file_name, e, e.lineno) from e
84
+
85
+ except jinja2.UndefinedError as e:
86
+ raise InvalidTemplateInFileError(src_file_name, e) from e
87
+
88
+ if expanded_template != file.contents:
89
+ file.edited_contents = expanded_template
90
+
44
91
  def process(
45
92
  self,
46
93
  artifact_to_process: PathMapping,
47
94
  processor_mapping: Optional[ProcessorMapping],
48
95
  **kwargs,
49
- ):
96
+ ) -> None:
50
97
  """
51
98
  Process the artifact by executing the template expansion logic on it.
52
99
  """
53
- cc.step(f"Processing artifact {artifact_to_process} with templates processor")
54
100
 
55
101
  bundle_map = BundleMap(
56
102
  project_root=self._bundle_ctx.project_root,
@@ -62,32 +108,4 @@ class TemplatesProcessor(ArtifactProcessor):
62
108
  absolute=True,
63
109
  expand_directories=True,
64
110
  ):
65
- if src.is_dir():
66
- continue
67
- with self.edit_file(dest) as f:
68
- file_name = src.relative_to(self._bundle_ctx.project_root)
69
-
70
- jinja_env = (
71
- choose_sql_jinja_env_based_on_template_syntax(
72
- f.contents, reference_name=file_name
73
- )
74
- if dest.name.lower().endswith(".sql")
75
- else get_client_side_jinja_env()
76
- )
77
-
78
- try:
79
- expanded_template = jinja_env.from_string(f.contents).render(
80
- get_cli_context().template_context
81
- )
82
-
83
- # For now, we are printing the source file path in the error message
84
- # instead of the destination file path to make it easier for the user
85
- # to identify the file that has the error, and edit the correct file.
86
- except jinja2.TemplateSyntaxError as e:
87
- raise InvalidTemplateInFileError(file_name, e, e.lineno) from e
88
-
89
- except jinja2.UndefinedError as e:
90
- raise InvalidTemplateInFileError(file_name, e) from e
91
-
92
- if expanded_template != f.contents:
93
- f.edited_contents = expanded_template
111
+ self.expand_templates_in_file(src, dest)