snowflake-cli 3.10.1__py3-none-any.whl → 3.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. snowflake/cli/__about__.py +1 -1
  2. snowflake/cli/_app/auth/__init__.py +13 -0
  3. snowflake/cli/_app/auth/errors.py +28 -0
  4. snowflake/cli/_app/auth/oidc_providers.py +393 -0
  5. snowflake/cli/_app/cli_app.py +0 -1
  6. snowflake/cli/_app/constants.py +10 -0
  7. snowflake/cli/_app/printing.py +153 -19
  8. snowflake/cli/_app/snow_connector.py +35 -0
  9. snowflake/cli/_plugins/auth/__init__.py +4 -2
  10. snowflake/cli/_plugins/auth/keypair/commands.py +2 -0
  11. snowflake/cli/_plugins/auth/oidc/__init__.py +13 -0
  12. snowflake/cli/_plugins/auth/oidc/commands.py +47 -0
  13. snowflake/cli/_plugins/auth/oidc/manager.py +66 -0
  14. snowflake/cli/_plugins/auth/oidc/plugin_spec.py +30 -0
  15. snowflake/cli/_plugins/connection/commands.py +37 -3
  16. snowflake/cli/_plugins/dbt/commands.py +37 -8
  17. snowflake/cli/_plugins/dbt/manager.py +144 -12
  18. snowflake/cli/_plugins/dcm/commands.py +102 -136
  19. snowflake/cli/_plugins/dcm/manager.py +136 -89
  20. snowflake/cli/_plugins/logs/commands.py +7 -0
  21. snowflake/cli/_plugins/logs/manager.py +21 -1
  22. snowflake/cli/_plugins/nativeapp/sf_sql_facade.py +3 -1
  23. snowflake/cli/_plugins/notebook/notebook_entity.py +2 -0
  24. snowflake/cli/_plugins/notebook/notebook_entity_model.py +8 -1
  25. snowflake/cli/_plugins/object/command_aliases.py +16 -1
  26. snowflake/cli/_plugins/object/commands.py +27 -1
  27. snowflake/cli/_plugins/object/manager.py +12 -1
  28. snowflake/cli/_plugins/snowpark/commands.py +8 -1
  29. snowflake/cli/_plugins/snowpark/common.py +1 -0
  30. snowflake/cli/_plugins/snowpark/package/anaconda_packages.py +29 -5
  31. snowflake/cli/_plugins/snowpark/package_utils.py +44 -3
  32. snowflake/cli/_plugins/spcs/services/manager.py +5 -4
  33. snowflake/cli/_plugins/sql/lexer/types.py +1 -0
  34. snowflake/cli/_plugins/sql/repl.py +100 -26
  35. snowflake/cli/_plugins/sql/repl_commands.py +607 -0
  36. snowflake/cli/_plugins/sql/statement_reader.py +44 -20
  37. snowflake/cli/api/artifacts/bundle_map.py +32 -2
  38. snowflake/cli/api/artifacts/regex_resolver.py +54 -0
  39. snowflake/cli/api/artifacts/upload.py +5 -1
  40. snowflake/cli/api/artifacts/utils.py +12 -1
  41. snowflake/cli/api/cli_global_context.py +7 -0
  42. snowflake/cli/api/commands/decorators.py +7 -0
  43. snowflake/cli/api/commands/flags.py +26 -0
  44. snowflake/cli/api/config.py +24 -0
  45. snowflake/cli/api/connections.py +1 -0
  46. snowflake/cli/api/console/abc.py +13 -2
  47. snowflake/cli/api/console/console.py +20 -0
  48. snowflake/cli/api/constants.py +9 -0
  49. snowflake/cli/api/entities/utils.py +10 -6
  50. snowflake/cli/api/feature_flags.py +1 -0
  51. snowflake/cli/api/identifiers.py +18 -1
  52. snowflake/cli/api/project/schemas/entities/entities.py +0 -6
  53. snowflake/cli/api/rendering/sql_templates.py +2 -0
  54. snowflake/cli/api/utils/dict_utils.py +42 -1
  55. {snowflake_cli-3.10.1.dist-info → snowflake_cli-3.12.0.dist-info}/METADATA +15 -41
  56. {snowflake_cli-3.10.1.dist-info → snowflake_cli-3.12.0.dist-info}/RECORD +59 -52
  57. snowflake/cli/_plugins/dcm/dcm_project_entity_model.py +0 -59
  58. snowflake/cli/_plugins/sql/snowsql_commands.py +0 -331
  59. {snowflake_cli-3.10.1.dist-info → snowflake_cli-3.12.0.dist-info}/WHEEL +0 -0
  60. {snowflake_cli-3.10.1.dist-info → snowflake_cli-3.12.0.dist-info}/entry_points.txt +0 -0
  61. {snowflake_cli-3.10.1.dist-info → snowflake_cli-3.12.0.dist-info}/licenses/LICENSE +0 -0
@@ -22,11 +22,16 @@ from typing import Dict, Optional
22
22
  import snowflake.connector
23
23
  from click.exceptions import ClickException
24
24
  from snowflake.cli import __about__
25
+ from snowflake.cli._app.auth.oidc_providers import OidcProviderTypeWithAuto
25
26
  from snowflake.cli._app.constants import (
27
+ AUTHENTICATOR_WORKLOAD_IDENTITY,
26
28
  INTERNAL_APPLICATION_NAME,
27
29
  PARAM_APPLICATION_NAME,
28
30
  )
29
31
  from snowflake.cli._app.telemetry import command_info
32
+ from snowflake.cli._plugins.auth.oidc.manager import (
33
+ OidcManager,
34
+ )
30
35
  from snowflake.cli.api.config import (
31
36
  get_connection_dict,
32
37
  get_env_value,
@@ -40,6 +45,7 @@ from snowflake.cli.api.feature_flags import FeatureFlag
40
45
  from snowflake.cli.api.secret import SecretType
41
46
  from snowflake.cli.api.secure_path import SecurePath
42
47
  from snowflake.connector import SnowflakeConnection
48
+ from snowflake.connector.auth.workload_identity import ApiFederatedAuthenticationType
43
49
  from snowflake.connector.errors import DatabaseError, ForbiddenError
44
50
 
45
51
  log = logging.getLogger(__name__)
@@ -54,6 +60,7 @@ SUPPORTED_ENV_OVERRIDES = [
54
60
  "user",
55
61
  "password",
56
62
  "authenticator",
63
+ "workload_identity_provider",
57
64
  "private_key_file",
58
65
  "private_key_path",
59
66
  "private_key_raw",
@@ -153,6 +160,14 @@ def connect_to_snowflake(
153
160
  if connection_parameters.get("authenticator") == "username_password_mfa":
154
161
  connection_parameters["client_request_mfa_token"] = True
155
162
 
163
+ # Handle WORKLOAD_IDENTITY authenticator (OIDC authentication)
164
+ if (
165
+ connection_parameters.get("authenticator") == AUTHENTICATOR_WORKLOAD_IDENTITY
166
+ and connection_parameters.get("workload_identity_provider")
167
+ == ApiFederatedAuthenticationType.OIDC.value
168
+ ):
169
+ _maybe_update_oidc_token(connection_parameters)
170
+
156
171
  if enable_diag:
157
172
  connection_parameters["enable_connection_diag"] = enable_diag
158
173
  if diag_log_path:
@@ -335,3 +350,23 @@ def prepare_private_key(
335
350
  encryption_algorithm=NoEncryption(),
336
351
  )
337
352
  )
353
+
354
+
355
def _maybe_update_oidc_token(connection_parameters: dict) -> dict:
    """Try to obtain OIDC token automatically.

    Asks ``OidcManager`` for a token using provider auto-detection. When a
    non-empty token comes back it is stored under the ``"token"`` key of
    ``connection_parameters`` (the dict is mutated in place).

    This is deliberately best-effort: any exception raised while detecting a
    provider or reading the token is logged and swallowed, so connecting
    proceeds with the parameters already supplied.

    Args:
        connection_parameters: Connection parameters; updated in place.

    Returns:
        The same ``connection_parameters`` dict.
    """
    try:
        token = OidcManager().read_token(OidcProviderTypeWithAuto.AUTO)
    except Exception as e:  # best-effort: never block the connection attempt
        log.info(
            "No token found during %s auto-detection: %s",
            AUTHENTICATOR_WORKLOAD_IDENTITY,
            str(e),
        )
        return connection_parameters

    if token:
        log.info("%s token acquired automatically", AUTHENTICATOR_WORKLOAD_IDENTITY)
        # NOTE(review): this overwrites any "token" already present in
        # connection_parameters — confirm that precedence is intended.
        connection_parameters["token"] = token
    return connection_parameters
@@ -1,11 +1,13 @@
1
1
  from snowflake.cli._plugins.auth.keypair.commands import app as keypair_app
2
+ from snowflake.cli._plugins.auth.oidc.commands import (
3
+ app as oidc_app,
4
+ )
2
5
  from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
3
- from snowflake.cli.api.feature_flags import FeatureFlag
4
6
 
5
7
  app = SnowTyperFactory(
6
8
  name="auth",
7
9
  help="Manages authentication methods.",
8
- is_hidden=lambda: FeatureFlag.ENABLE_AUTH_KEYPAIR.is_disabled(),
9
10
  )
10
11
 
11
12
  app.add_typer(keypair_app)
13
+ app.add_typer(oidc_app)
@@ -4,6 +4,7 @@ import typer
4
4
  from snowflake.cli._plugins.auth.keypair.manager import AuthManager, PublicKeyProperty
5
5
  from snowflake.cli.api.commands.flags import SecretTypeParser
6
6
  from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
7
+ from snowflake.cli.api.feature_flags import FeatureFlag
7
8
  from snowflake.cli.api.output.types import (
8
9
  CollectionResult,
9
10
  CommandResult,
@@ -16,6 +17,7 @@ from snowflake.cli.api.secure_path import SecurePath
16
17
  app = SnowTyperFactory(
17
18
  name="keypair",
18
19
  help="Manages authentication.",
20
+ is_hidden=lambda: FeatureFlag.ENABLE_AUTH_KEYPAIR.is_disabled(),
19
21
  )
20
22
 
21
23
 
@@ -0,0 +1,13 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,47 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import typer
16
+ from snowflake.cli._app.auth.oidc_providers import (
17
+ OidcProviderTypeWithAuto,
18
+ )
19
+ from snowflake.cli._plugins.auth.oidc.manager import OidcManager
20
+ from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
21
+ from snowflake.cli.api.output.types import MessageResult
22
+
23
app = SnowTyperFactory(
    name="oidc",
    help="Manages OIDC authentication.",
)


# `--type` option for OIDC commands; defaults to the "auto" provider type.
# (Plain string: the former f-string had no placeholders.)
AutoProviderTypeOption = typer.Option(
    OidcProviderTypeWithAuto.AUTO.value,
    "--type",
    help="Type of OIDC provider to use",
    show_default=False,
)
35
+
36
+
37
@app.command("read-token", requires_connection=False)
def read_token(
    _type: OidcProviderTypeWithAuto = AutoProviderTypeOption,
    **options,
):
    """
    Reads OIDC token based on the specified type.
    Use 'auto' to auto-detect available providers.
    """
    # The docstring above doubles as the command's help text, so it is kept verbatim.
    oidc_manager = OidcManager()
    token = oidc_manager.read_token(provider_type=_type)
    return MessageResult(token)
@@ -0,0 +1,66 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ from typing import TypeAlias
17
+
18
+ from snowflake.cli._app.auth.errors import OidcProviderError
19
+ from snowflake.cli._app.auth.oidc_providers import (
20
+ OidcProviderType,
21
+ OidcProviderTypeWithAuto,
22
+ auto_detect_oidc_provider,
23
+ get_active_oidc_provider,
24
+ )
25
+ from snowflake.cli.api.exceptions import CliError
26
+
27
logger = logging.getLogger(__name__)


# Any provider selector accepted by OidcManager.read_token: a concrete
# provider type or the "auto" option.
Providers: TypeAlias = OidcProviderType | OidcProviderTypeWithAuto


class OidcManager:
    """
    Manages OIDC authentication.

    This class provides methods to read OIDC configurations for authentication.
    """

    def read_token(
        self,
        provider_type: Providers = OidcProviderTypeWithAuto.AUTO,
    ) -> str:
        """
        Reads OIDC token based on the specified provider type.

        Args:
            provider_type: Type of provider to read token from ("auto" for
                auto-detection). Any other value is resolved through
                ``get_active_oidc_provider(provider_type.value)``.

        Returns:
            The token string returned by the selected provider's ``get_token()``.

        Raises:
            CliError: If selecting the provider or reading the token raises
                OidcProviderError; the original error is chained as the cause.
        """
        logger.info("Reading OIDC token with provider type: %s", provider_type)

        try:
            if provider_type == OidcProviderTypeWithAuto.AUTO:
                provider = auto_detect_oidc_provider()
            else:
                provider = get_active_oidc_provider(provider_type.value)
            return provider.get_token()
        except OidcProviderError as e:
            logger.error("OIDC provider error: %s", str(e))
            # Chain the provider error so the root cause stays in tracebacks.
            raise CliError(str(e)) from e
@@ -0,0 +1,30 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from snowflake.cli._plugins.auth import app
16
+ from snowflake.cli.api.plugins.command import (
17
+ SNOWCLI_ROOT_COMMAND_PATH,
18
+ CommandSpec,
19
+ CommandType,
20
+ plugin_hook_impl,
21
+ )
22
+
23
+
24
@plugin_hook_impl
def command_spec():
    # Register the app imported from snowflake.cli._plugins.auth as a
    # command group directly under the CLI root.
    typer_instance = app.create_instance()
    return CommandSpec(
        parent_command_path=SNOWCLI_ROOT_COMMAND_PATH,
        command_type=CommandType.COMMAND_GROUP,
        typer_instance=typer_instance,
    )
@@ -53,6 +53,7 @@ from snowflake.cli.api.commands.flags import (
53
53
  TokenFilePathOption,
54
54
  UserOption,
55
55
  WarehouseOption,
56
+ WorkloadIdentityProviderOption,
56
57
  )
57
58
  from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
58
59
  from snowflake.cli.api.config import (
@@ -62,7 +63,9 @@ from snowflake.cli.api.config import (
62
63
  get_all_connections,
63
64
  get_connection_dict,
64
65
  get_default_connection_name,
66
+ remove_connection_from_proper_file,
65
67
  set_config_value,
68
+ unset_config_value,
66
69
  )
67
70
  from snowflake.cli.api.console import cli_console
68
71
  from snowflake.cli.api.constants import ObjectType
@@ -220,6 +223,12 @@ def add(
220
223
  *AuthenticatorOption.param_decls,
221
224
  help="Chosen authenticator, if other than password-based",
222
225
  ),
226
+ workload_identity_provider: Optional[str] = typer.Option(
227
+ None,
228
+ "-W",
229
+ *WorkloadIdentityProviderOption.param_decls,
230
+ help="Workload identity provider type",
231
+ ),
223
232
  private_key_file: Optional[str] = typer.Option(
224
233
  None,
225
234
  "--private-key",
@@ -256,6 +265,7 @@ def add(
256
265
  "port": port,
257
266
  "region": region,
258
267
  "authenticator": authenticator,
268
+ "workload_identity_provider": workload_identity_provider,
259
269
  "private_key_file": private_key_file,
260
270
  "token_file_path": token_file_path,
261
271
  }
@@ -317,6 +327,30 @@ def add(
317
327
  )
318
328
 
319
329
 
330
@app.command(requires_connection=False)
def remove(
    connection_name: str = typer.Argument(
        help="Name of the connection to remove.",
        show_default=False,
    ),
    **options,
):
    """Removes a connection from configuration file."""
    if not connection_exists(connection_name):
        raise UsageError(f"Connection {connection_name} does not exist.")

    # If this connection is the default one, clear default_connection_name too.
    was_default = connection_name == get_default_connection_name()
    if was_default:
        unset_config_value(path=["default_connection_name"])

    removed_from = remove_connection_from_proper_file(connection_name)

    message = f"Removed connection {connection_name} from {removed_from}."
    if was_default:
        message += " It was the default connection, so default connection is now unset."
    return MessageResult(message)
352
+
353
+
320
354
  @app.command(requires_connection=True)
321
355
  def test(
322
356
  **options,
@@ -355,9 +389,9 @@ def test(
355
389
  "Host": conn.host,
356
390
  "Account": conn.account,
357
391
  "User": conn.user,
358
- "Role": f'{conn.role or "not set"}',
359
- "Database": f'{conn.database or "not set"}',
360
- "Warehouse": f'{conn.warehouse or "not set"}',
392
+ "Role": f"{conn.role or 'not set'}",
393
+ "Database": f"{conn.database or 'not set'}",
394
+ "Warehouse": f"{conn.warehouse or 'not set'}",
361
395
  }
362
396
 
363
397
  if conn_ctx.enable_diag:
@@ -19,7 +19,6 @@ from typing import Optional
19
19
 
20
20
  import typer
21
21
  from click import types
22
- from rich.progress import Progress, SpinnerColumn, TextColumn
23
22
  from snowflake.cli._plugins.dbt.constants import (
24
23
  DBT_COMMANDS,
25
24
  OUTPUT_COLUMN_NAME,
@@ -31,7 +30,9 @@ from snowflake.cli._plugins.object.command_aliases import add_object_command_ali
31
30
  from snowflake.cli._plugins.object.commands import scope_option
32
31
  from snowflake.cli.api.commands.decorators import global_options_with_connection
33
32
  from snowflake.cli.api.commands.flags import identifier_argument, like_option
33
+ from snowflake.cli.api.commands.overrideable_parameter import OverrideableOption
34
34
  from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
35
+ from snowflake.cli.api.console.console import cli_console
35
36
  from snowflake.cli.api.constants import ObjectType
36
37
  from snowflake.cli.api.exceptions import CliError
37
38
  from snowflake.cli.api.feature_flags import FeatureFlag
@@ -59,6 +60,16 @@ DBTNameArgument = identifier_argument(sf_object="DBT Project", example="my_pipel
59
60
  DBTNameOrCommandArgument = identifier_argument(
60
61
  sf_object="DBT Project", example="my_pipeline", click_type=types.StringParamType()
61
62
  )
63
# Parameter names used by `deploy_dbt` for the two options below; each option
# declares the other one as mutually exclusive.
_DEFAULT_TARGET_PARAM = "default_target"
_UNSET_DEFAULT_TARGET_PARAM = "unset_default_target"

DefaultTargetOption = OverrideableOption(
    None, "--default-target", mutually_exclusive=[_UNSET_DEFAULT_TARGET_PARAM]
)
UnsetDefaultTargetOption = OverrideableOption(
    False, "--unset-default-target", mutually_exclusive=[_DEFAULT_TARGET_PARAM]
)
62
73
 
63
74
  add_object_command_aliases(
64
75
  app=app,
@@ -92,6 +103,21 @@ def deploy_dbt(
92
103
  False,
93
104
  help="Overwrites conflicting files in the project, if any.",
94
105
  ),
106
+ default_target: Optional[str] = DefaultTargetOption(
107
+ help="Default target for the dbt project. Mutually exclusive with --unset-default-target.",
108
+ hidden=FeatureFlag.ENABLE_DBT_GA_FEATURES.is_disabled(),
109
+ ),
110
+ unset_default_target: Optional[bool] = UnsetDefaultTargetOption(
111
+ help="Unset the default target for the dbt project. Mutually exclusive with --default-target.",
112
+ hidden=FeatureFlag.ENABLE_DBT_GA_FEATURES.is_disabled(),
113
+ ),
114
+ external_access_integrations: Optional[list[str]] = typer.Option(
115
+ None,
116
+ "--external-access-integration",
117
+ show_default=False,
118
+ help="External access integration to be used by the dbt object.",
119
+ hidden=FeatureFlag.ENABLE_DBT_GA_FEATURES.is_disabled(),
120
+ ),
95
121
  **options,
96
122
  ) -> CommandResult:
97
123
  """
@@ -99,6 +125,11 @@ def deploy_dbt(
99
125
  provided; or create a new one if it doesn't exist; or update files and
100
126
  create a new version if it exists.
101
127
  """
128
+ if FeatureFlag.ENABLE_DBT_GA_FEATURES.is_disabled():
129
+ default_target = None
130
+ unset_default_target = False
131
+ external_access_integrations = None
132
+
102
133
  project_path = SecurePath(source) if source is not None else SecurePath.cwd()
103
134
  profiles_dir_path = SecurePath(profiles_dir) if profiles_dir else project_path
104
135
  return QueryResult(
@@ -107,6 +138,9 @@ def deploy_dbt(
107
138
  project_path.resolve(),
108
139
  profiles_dir_path.resolve(),
109
140
  force=force,
141
+ default_target=default_target,
142
+ unset_default_target=unset_default_target,
143
+ external_access_integrations=external_access_integrations,
110
144
  )
111
145
  )
112
146
 
@@ -161,13 +195,8 @@ for cmd in DBT_COMMANDS:
161
195
  f"Command submitted. You can check the result with `snow sql -q \"select execution_status from table(information_schema.query_history_by_user()) where query_id in ('{result.sfqid}');\"`"
162
196
  )
163
197
 
164
- with Progress(
165
- SpinnerColumn(),
166
- TextColumn("[progress.description]{task.description}"),
167
- transient=True,
168
- ) as progress:
169
- progress.add_task(description=f"Executing 'dbt {dbt_command}'", total=None)
170
-
198
+ with cli_console.spinner() as spinner:
199
+ spinner.add_task(description=f"Executing 'dbt {dbt_command}'", total=None)
171
200
  result = dbt_manager.execute(*execute_args)
172
201
 
173
202
  try:
@@ -17,6 +17,7 @@ from __future__ import annotations
17
17
  from collections import defaultdict
18
18
  from pathlib import Path
19
19
  from tempfile import TemporaryDirectory
20
+ from typing import List, Optional, TypedDict
20
21
 
21
22
  import yaml
22
23
  from snowflake.cli._plugins.dbt.constants import PROFILES_FILENAME
@@ -26,10 +27,14 @@ from snowflake.cli.api.console import cli_console
26
27
  from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB, ObjectType
27
28
  from snowflake.cli.api.exceptions import CliError
28
29
  from snowflake.cli.api.identifiers import FQN
29
- from snowflake.cli.api.project.util import unquote_identifier
30
30
  from snowflake.cli.api.secure_path import SecurePath
31
31
  from snowflake.cli.api.sql_execution import SqlExecutionMixin
32
32
  from snowflake.connector.cursor import SnowflakeCursor
33
+ from snowflake.connector.errors import ProgrammingError
34
+
35
+
36
class DBTObjectEditableAttributes(TypedDict):
    """Attributes of an existing DBT project that deploy may reconcile."""

    # Current default target of the project; None when it has none.
    default_target: Optional[str]
33
38
 
34
39
 
35
40
  class DBTManager(SqlExecutionMixin):
@@ -43,12 +48,44 @@ class DBTManager(SqlExecutionMixin):
43
48
  object_type=ObjectType.DBT_PROJECT.value.cli_name, fqn=name
44
49
  )
45
50
 
51
@staticmethod
def describe(name: FQN) -> SnowflakeCursor:
    # Delegate to ObjectManager.describe using the DBT project object type.
    object_type = ObjectType.DBT_PROJECT.value.cli_name
    return ObjectManager().describe(object_type=object_type, fqn=name)
56
+
57
@staticmethod
def get_dbt_object_attributes(name: FQN) -> Optional[DBTObjectEditableAttributes]:
    """Get editable attributes of an existing DBT project, or None if it doesn't exist.

    Runs DESCRIBE on the project and maps the first result row (by column
    name) into a DBTObjectEditableAttributes dict.

    Raises:
        ProgrammingError: For any DESCRIBE failure other than the project
            not existing.
    """
    try:
        cursor = DBTManager().describe(name)
    except ProgrammingError as exc:
        # exc.msg may be None; guard so the substring checks cannot raise a
        # TypeError that would mask the original error.
        message = exc.msg or ""
        if "DBT PROJECT" in message and "does not exist" in message:
            return None
        raise

    rows = list(cursor)
    if not rows:
        return None

    # Convert the first row to a dict keyed by column name.
    columns = [desc[0] for desc in cursor.description]
    row_dict = dict(zip(columns, rows[0]))

    return DBTObjectEditableAttributes(
        default_target=row_dict.get("default_target")
    )
79
+
46
80
  def deploy(
47
81
  self,
48
82
  fqn: FQN,
49
83
  path: SecurePath,
50
84
  profiles_path: SecurePath,
51
85
  force: bool,
86
+ default_target: Optional[str] = None,
87
+ unset_default_target: bool = False,
88
+ external_access_integrations: Optional[List[str]] = None,
52
89
  ) -> SnowflakeCursor:
53
90
  dbt_project_path = path / "dbt_project.yml"
54
91
  if not dbt_project_path.exists():
@@ -63,14 +100,13 @@ class DBTManager(SqlExecutionMixin):
63
100
  except KeyError:
64
101
  raise CliError("`profile` is not defined in dbt_project.yml")
65
102
 
66
- self._validate_profiles(profiles_path, profile)
103
+ self._validate_profiles(profiles_path, profile, default_target)
67
104
 
68
105
  with cli_console.phase("Creating temporary stage"):
69
106
  stage_manager = StageManager()
70
- unquoted_name = unquote_identifier(fqn.name)
71
- stage_fqn = FQN.from_string(f"DBT_{unquoted_name}_STAGE").using_context()
72
- stage_name = stage_manager.get_standard_stage_prefix(stage_fqn)
107
+ stage_fqn = FQN.from_resource(ObjectType.DBT_PROJECT, fqn, "STAGE")
73
108
  stage_manager.create(stage_fqn, temporary=True)
109
+ stage_name = stage_manager.get_standard_stage_prefix(stage_fqn)
74
110
 
75
111
  with cli_console.phase("Copying project files to stage"):
76
112
  with TemporaryDirectory() as tmp:
@@ -88,22 +124,109 @@ class DBTManager(SqlExecutionMixin):
88
124
 
89
125
  with cli_console.phase("Creating DBT project"):
90
126
  if force is True:
91
- query = f"CREATE OR REPLACE DBT PROJECT {fqn}"
92
- elif self.exists(name=fqn):
93
- query = f"ALTER DBT PROJECT {fqn} ADD VERSION"
127
+ return self._deploy_create_or_replace(
128
+ fqn, stage_name, default_target, external_access_integrations
129
+ )
94
130
  else:
95
- query = f"CREATE DBT PROJECT {fqn}"
96
- query += f"\nFROM {stage_name}"
97
- return self.execute_query(query)
131
+ dbt_object_attributes = self.get_dbt_object_attributes(fqn)
132
+ if dbt_object_attributes is not None:
133
+ return self._deploy_alter(
134
+ fqn,
135
+ stage_name,
136
+ dbt_object_attributes,
137
+ default_target,
138
+ unset_default_target,
139
+ external_access_integrations,
140
+ )
141
+ else:
142
+ return self._deploy_create(
143
+ fqn, stage_name, default_target, external_access_integrations
144
+ )
145
+
146
def _deploy_alter(
    self,
    fqn: FQN,
    stage_name: str,
    dbt_object_attributes: DBTObjectEditableAttributes,
    default_target: Optional[str],
    unset_default_target: bool,
    external_access_integrations: Optional[List[str]],
) -> SnowflakeCursor:
    """Add a new version to an existing DBT project and reconcile its default target.

    Runs ``ALTER DBT PROJECT <fqn> ADD VERSION FROM <stage_name>`` (plus an
    EXTERNAL_ACCESS_INTEGRATIONS clause when integrations are given). Then:

    * if ``unset_default_target`` is true and the project currently has a
      default target, it is unset;
    * otherwise, if ``default_target`` is given and differs
      (case-insensitively) from the current one, it is set.

    Returns:
        The cursor of the ADD VERSION query.
    """
    query = f"ALTER DBT PROJECT {fqn} ADD VERSION"
    query += f"\nFROM {stage_name}"
    query = self._handle_external_access_integrations_query(
        query, external_access_integrations
    )
    result = self.execute_query(query)

    current_default_target = dbt_object_attributes.get("default_target")
    if unset_default_target and current_default_target is not None:
        unset_query = f"ALTER DBT PROJECT {fqn} UNSET DEFAULT_TARGET"
        self.execute_query(unset_query)
    elif default_target and (
        current_default_target is None
        or current_default_target.lower() != default_target.lower()
    ):
        # Escape backslashes and single quotes so the target name forms a
        # valid single-quoted Snowflake string literal.
        escaped_target = default_target.replace("\\", "\\\\").replace("'", "\\'")
        set_default_query = (
            f"ALTER DBT PROJECT {fqn} SET DEFAULT_TARGET='{escaped_target}'"
        )
        self.execute_query(set_default_query)
    return result
174
+
175
def _deploy_create(
    self,
    fqn: FQN,
    stage_name: str,
    default_target: Optional[str],
    external_access_integrations: Optional[List[str]],
) -> SnowflakeCursor:
    """Create a new DBT project from the files in ``stage_name``.

    Appends ``DEFAULT_TARGET='<target>'`` when a default target is given and
    an EXTERNAL_ACCESS_INTEGRATIONS clause when integrations are given.

    Returns:
        The cursor of the CREATE query.
    """
    # Project doesn't exist - create new one
    query = f"CREATE DBT PROJECT {fqn}"
    query += f"\nFROM {stage_name}"
    if default_target:
        # Escape backslashes and single quotes for a single-quoted literal.
        escaped_target = default_target.replace("\\", "\\\\").replace("'", "\\'")
        query += f" DEFAULT_TARGET='{escaped_target}'"
    query = self._handle_external_access_integrations_query(
        query, external_access_integrations
    )
    return self.execute_query(query)
191
+
192
@staticmethod
def _handle_external_access_integrations_query(
    query: str, external_access_integrations: Optional[List[str]]
) -> str:
    """Append an EXTERNAL_ACCESS_INTEGRATIONS clause to ``query`` when any are given.

    Integration names are inserted verbatim, comma-separated, on a new line.
    None or an empty list returns ``query`` unchanged.
    """
    if not external_access_integrations:
        return query
    joined = ", ".join(external_access_integrations)
    return f"{query}\nEXTERNAL_ACCESS_INTEGRATIONS = ({joined})"
200
+
201
def _deploy_create_or_replace(
    self,
    fqn: FQN,
    stage_name: str,
    default_target: Optional[str],
    external_access_integrations: Optional[List[str]],
) -> SnowflakeCursor:
    """Create (or replace) the DBT project from the files in ``stage_name``.

    Appends ``DEFAULT_TARGET='<target>'`` when a default target is given and
    an EXTERNAL_ACCESS_INTEGRATIONS clause when integrations are given.

    Returns:
        The cursor of the CREATE OR REPLACE query.
    """
    query = f"CREATE OR REPLACE DBT PROJECT {fqn}"
    query += f"\nFROM {stage_name}"
    if default_target:
        # Escape backslashes and single quotes for a single-quoted literal.
        escaped_target = default_target.replace("\\", "\\\\").replace("'", "\\'")
        query += f" DEFAULT_TARGET='{escaped_target}'"
    query = self._handle_external_access_integrations_query(
        query, external_access_integrations
    )
    return self.execute_query(query)
98
216
 
99
217
  @staticmethod
100
- def _validate_profiles(profiles_path: SecurePath, target_profile: str) -> None:
218
+ def _validate_profiles(
219
+ profiles_path: SecurePath,
220
+ target_profile: str,
221
+ default_target: str | None = None,
222
+ ) -> None:
101
223
  """
102
224
  Validates that:
103
225
  * profiles.yml exists
104
226
  * contain profile specified in dbt_project.yml
105
227
  * no other profiles are defined there
106
228
  * does not contain any confidential data like passwords
229
+ * default_target (if specified) exists in the profile's outputs
107
230
  """
108
231
  profiles_file = profiles_path / PROFILES_FILENAME
109
232
  if not profiles_file.exists():
@@ -154,6 +277,15 @@ class DBTManager(SqlExecutionMixin):
154
277
  f"Value for type field is invalid. Should be set to `snowflake` in target {target_name}"
155
278
  )
156
279
 
280
+ if default_target is not None:
281
+ available_targets = set(profiles[target_profile]["outputs"].keys())
282
+ if default_target not in available_targets:
283
+ available_targets_str = ", ".join(sorted(available_targets))
284
+ errors["default_target"].append(
285
+ f"Default target '{default_target}' is not defined in profile '{target_profile}'. "
286
+ f"Available targets: {available_targets_str}"
287
+ )
288
+
157
289
  if errors:
158
290
  message = f"Found following errors in {PROFILES_FILENAME}. Please fix them before proceeding:"
159
291
  for target, issues in errors.items():