snowflake-cli 3.10.1__py3-none-any.whl → 3.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32):
  1. snowflake/cli/__about__.py +1 -1
  2. snowflake/cli/_app/auth/__init__.py +13 -0
  3. snowflake/cli/_app/auth/errors.py +28 -0
  4. snowflake/cli/_app/auth/oidc_providers.py +393 -0
  5. snowflake/cli/_app/constants.py +10 -0
  6. snowflake/cli/_app/snow_connector.py +35 -0
  7. snowflake/cli/_plugins/auth/__init__.py +4 -2
  8. snowflake/cli/_plugins/auth/keypair/commands.py +2 -0
  9. snowflake/cli/_plugins/auth/oidc/__init__.py +13 -0
  10. snowflake/cli/_plugins/auth/oidc/commands.py +47 -0
  11. snowflake/cli/_plugins/auth/oidc/manager.py +66 -0
  12. snowflake/cli/_plugins/auth/oidc/plugin_spec.py +30 -0
  13. snowflake/cli/_plugins/connection/commands.py +37 -3
  14. snowflake/cli/_plugins/dbt/manager.py +1 -3
  15. snowflake/cli/_plugins/dcm/commands.py +79 -88
  16. snowflake/cli/_plugins/dcm/manager.py +17 -57
  17. snowflake/cli/_plugins/notebook/notebook_entity.py +2 -0
  18. snowflake/cli/_plugins/notebook/notebook_entity_model.py +8 -1
  19. snowflake/cli/_plugins/object/command_aliases.py +16 -1
  20. snowflake/cli/_plugins/object/commands.py +27 -1
  21. snowflake/cli/_plugins/object/manager.py +12 -1
  22. snowflake/cli/_plugins/snowpark/commands.py +8 -1
  23. snowflake/cli/api/commands/decorators.py +7 -0
  24. snowflake/cli/api/commands/flags.py +26 -0
  25. snowflake/cli/api/config.py +24 -0
  26. snowflake/cli/api/connections.py +1 -0
  27. snowflake/cli/api/utils/dict_utils.py +42 -1
  28. {snowflake_cli-3.10.1.dist-info → snowflake_cli-3.11.0.dist-info}/METADATA +12 -38
  29. {snowflake_cli-3.10.1.dist-info → snowflake_cli-3.11.0.dist-info}/RECORD +32 -25
  30. {snowflake_cli-3.10.1.dist-info → snowflake_cli-3.11.0.dist-info}/WHEEL +0 -0
  31. {snowflake_cli-3.10.1.dist-info → snowflake_cli-3.11.0.dist-info}/entry_points.txt +0 -0
  32. {snowflake_cli-3.10.1.dist-info → snowflake_cli-3.11.0.dist-info}/licenses/LICENSE +0 -0
@@ -16,7 +16,7 @@ from __future__ import annotations
16
16
 
17
17
  from enum import Enum, unique
18
18
 
19
- VERSION = "3.10.1"
19
+ VERSION = "3.11.0"
20
20
 
21
21
 
22
22
  @unique
@@ -0,0 +1,13 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,28 @@
1
class OidcProviderError(Exception):
    """Root of the OIDC provider exception hierarchy.

    Catching this type handles every error raised by the OIDC provider layer.
    """


class OidcProviderNotFoundError(OidcProviderError):
    """Raised when a provider name does not match any registered provider."""


class OidcProviderUnavailableError(OidcProviderError):
    """Raised when a known provider cannot be used in the current environment."""


class OidcProviderAutoDetectionError(OidcProviderError):
    """Raised when auto-detection does not find exactly one usable provider."""


class OidcTokenRetrievalError(OidcProviderError):
    """Raised when a provider fails to produce an OIDC token."""
@@ -0,0 +1,393 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib
16
+ import inspect
17
+ import logging
18
+ import os
19
+ from abc import ABC, abstractmethod
20
+ from enum import Enum
21
+ from typing import Dict, List, Literal, Optional, Type
22
+
23
+ import id as oidc_id
24
+ from snowflake.cli._app.auth.errors import (
25
+ OidcProviderAutoDetectionError,
26
+ OidcProviderNotFoundError,
27
+ OidcProviderUnavailableError,
28
+ OidcTokenRetrievalError,
29
+ )
30
+
31
logger = logging.getLogger(__name__)


# Names of the environment variables read by the GitHub Actions provider.
GITHUB_ACTIONS_ENV: Literal["GITHUB_ACTIONS"] = "GITHUB_ACTIONS"
ACTIONS_ID_TOKEN_REQUEST_URL_ENV: Literal["ACTIONS_ID_TOKEN_REQUEST_URL"] = (
    "ACTIONS_ID_TOKEN_REQUEST_URL"
)
# Optional override for the audience requested for the OIDC token.
SNOWFLAKE_AUDIENCE_ENV: Literal["SNOWFLAKE_AUDIENCE"] = "SNOWFLAKE_AUDIENCE"
39
+
40
+
41
class OidcProviderType(Enum):
    """Concrete OIDC provider kinds that can be selected explicitly."""

    # Same string that GitHubOidcProvider reports as its provider_name.
    GITHUB = "github"
45
+
46
+
47
class OidcProviderTypeWithAuto(Enum):
    """Provider choices that also include AUTO.

    AUTO requests automatic detection of the single available provider;
    the remaining members mirror OidcProviderType.
    """

    AUTO = "auto"
    GITHUB = "github"
52
+
53
+
54
class OidcTokenProvider(ABC):
    """
    Interface that every CI-specific OIDC token source implements.

    Concrete subclasses defined in this module are picked up automatically
    by OidcProviderRegistry.
    """

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Short identifier of the CI provider (e.g. 'github', 'gitlab', 'azure-devops')."""
        ...

    @property
    @abstractmethod
    def is_available(self) -> bool:
        """True when this provider can find credentials in the current context."""
        ...

    @property
    @abstractmethod
    def issuer(self) -> str:
        """
        OIDC issuer URL associated with this provider.

        Returns:
            The OIDC issuer URL
        """
        ...

    @abstractmethod
    def get_token(self) -> str:
        """
        Fetch the OIDC token from the CI environment.

        Returns:
            The OIDC token string

        Raises:
            OidcProviderError: If the token cannot be obtained
        """
        ...
100
+
101
+
102
class GitHubOidcProvider(OidcTokenProvider):
    """
    OIDC token provider for GitHub Actions.

    Availability is decided by the GITHUB_ACTIONS environment variable, and
    tokens are obtained through ``id.detect_credential``.
    """

    @property
    def _is_ci(self) -> bool:
        """Return True when the GITHUB_ACTIONS variable equals the string "true"."""
        logger.debug("Checking if GitHub Actions environment is available")

        # Check if we're in a GitHub Actions environment
        github_actions_env = os.getenv(GITHUB_ACTIONS_ENV)
        logger.debug(
            "%s environment variable: %s",
            GITHUB_ACTIONS_ENV,
            github_actions_env,
        )

        is_github_actions = github_actions_env == "true"
        logger.debug("Running in GitHub Actions: %s", is_github_actions)
        return is_github_actions

    @property
    def audience(self) -> str:
        """
        Returns the audience URL for GitHub OIDC.

        Returns:
            The audience URL, defaults to 'snowflakecomputing.com' if SNOWFLAKE_AUDIENCE environment variable is not set
        """
        return os.getenv(SNOWFLAKE_AUDIENCE_ENV, "snowflakecomputing.com")

    @property
    def issuer(self) -> str:
        """
        Returns the GitHub OIDC issuer URL.

        Returns:
            The value of the ACTIONS_ID_TOKEN_REQUEST_URL environment variable,
            or "https://token.actions.githubusercontent.com" when it is unset
            outside GitHub Actions.

        Raises:
            OidcTokenRetrievalError: When running in GitHub Actions and
                ACTIONS_ID_TOKEN_REQUEST_URL is not set.

        NOTE(review): per GitHub's documentation, ACTIONS_ID_TOKEN_REQUEST_URL
        is the endpoint used to *request* an ID token. The ``iss`` claim of
        GitHub Actions tokens is https://token.actions.githubusercontent.com.
        Confirm which value callers actually expect here.
        """
        issuer_url = os.getenv(ACTIONS_ID_TOKEN_REQUEST_URL_ENV)
        if not issuer_url and self._is_ci:
            raise OidcTokenRetrievalError(
                "%s environment variable is not set. "
                "This variable is required for Github Actions OIDC authentication"
                % ACTIONS_ID_TOKEN_REQUEST_URL_ENV
            )
        return issuer_url or "https://token.actions.githubusercontent.com"

    @property
    def provider_name(self) -> str:
        return OidcProviderType.GITHUB.value

    @property
    def is_available(self) -> bool:
        """
        Checks if GitHub Actions environment is available.
        """
        return self._is_ci

    def get_token(self) -> str:
        """
        Retrieves the OIDC token from GitHub Actions.

        Returns:
            The OIDC token issued for ``self.audience``.

        Raises:
            OidcTokenRetrievalError: If credential detection fails, or if it
                succeeds but yields no token.
        """
        logger.debug("Retrieving OIDC token from GitHub Actions")
        logger.debug("Detecting OIDC credentials for token retrieval")

        # Only the detection call is guarded. Keeping the "no token" check
        # outside the try stops our own OidcTokenRetrievalError from being
        # caught and re-wrapped with a duplicated message.
        try:
            # Use configurable audience for workload identity
            token = oidc_id.detect_credential(self.audience)
        except Exception as e:
            logger.error("Failed to detect OIDC credentials: %s", str(e))
            raise OidcTokenRetrievalError(
                "Failed to detect OIDC credentials: %s" % str(e)
            ) from e

        if not token:
            logger.error("No OIDC credentials detected")
            raise OidcTokenRetrievalError(
                "No OIDC credentials detected. This command should be run in a GitHub Actions environment."
            )

        logger.info("Successfully retrieved OIDC token")
        return token
185
+
186
+
187
class OidcProviderRegistry:
    """
    Registry for managing OIDC token providers.

    Maps each provider's ``provider_name`` to its class. The constructor scans
    this module for concrete OidcTokenProvider subclasses, and more classes
    can be added later with ``register_provider``. Every lookup that returns
    a provider creates a new instance.
    """

    def __init__(self) -> None:
        self._providers: Dict[str, Type[OidcTokenProvider]] = {}
        self._auto_discover_providers()

    def _auto_discover_providers(self) -> None:
        """
        Auto-discovers all concrete OIDC token providers in the current module.

        Each class is instantiated once to read its ``provider_name``.
        Abstract subclasses are skipped because instantiating them would
        raise TypeError at import time.
        """
        logger.debug("Auto-discovering OIDC token providers")
        current_module = importlib.import_module(__name__)

        for name, obj in inspect.getmembers(current_module, inspect.isclass):
            if (
                obj is OidcTokenProvider
                or not issubclass(obj, OidcTokenProvider)
                or inspect.isabstract(obj)
            ):
                continue
            provider_name = obj().provider_name
            logger.debug("Discovered OIDC provider: %s (%s)", provider_name, name)
            self._providers[provider_name] = obj

        logger.info(
            "Auto-discovered %d OIDC provider(s): %s",
            len(self._providers),
            list(self._providers.keys()),
        )

    def register_provider(self, provider_class: Type[OidcTokenProvider]) -> None:
        """
        Manually register a provider class.

        The class is instantiated once to read its ``provider_name``. A class
        registered under an existing name replaces the previous entry.
        """
        provider_name = provider_class().provider_name
        self._providers[provider_name] = provider_class

    def get_provider(self, provider_name: str) -> Optional[OidcTokenProvider]:
        """
        Return a new instance of the provider registered as ``provider_name``,
        or None if no such provider is registered.
        """
        provider_class = self._providers.get(provider_name)
        if provider_class is None:
            return None
        return provider_class()

    def get_provider_class(
        self, provider_name: str
    ) -> Optional[Type[OidcTokenProvider]]:
        """
        Get a specific provider class by name, or None if it is not registered.
        """
        return self._providers.get(provider_name)

    @property
    def provider_names(self) -> List[str]:
        """
        List all registered provider names, in registration order.
        """
        return list(self._providers.keys())

    @property
    def all_providers(self) -> List[OidcTokenProvider]:
        """
        Get new instances of all registered providers.
        """
        return [provider_class() for provider_class in self._providers.values()]
258
+
259
+
260
# Module-wide registry consulted by the lookup helpers in this module. It is
# built (and its providers auto-discovered) once, when the module is imported.
_registry: OidcProviderRegistry = OidcProviderRegistry()
262
+
263
+
264
def get_oidc_provider(provider_name: str) -> OidcTokenProvider:
    """
    Look up an OIDC provider by name. Availability is NOT checked.

    Args:
        provider_name: Registered name of the provider (e.g. "github")

    Returns:
        A new instance of the matching provider

    Raises:
        OidcProviderNotFoundError: If no provider is registered under that name
    """
    found = _registry.get_provider(provider_name)
    if found:
        return found

    known = ", ".join(_registry.provider_names)
    raise OidcProviderNotFoundError(
        "Unknown provider '%s'. Available providers: %s" % (provider_name, known)
    )
290
+
291
+
292
def get_active_oidc_provider(provider_name: str) -> OidcTokenProvider:
    """
    Look up an OIDC provider by name and require that it is usable here.

    Args:
        provider_name: Registered name of the provider

    Returns:
        A new instance of the matching provider, whose ``is_available`` is True

    Raises:
        OidcProviderNotFoundError: If no provider is registered under that name
        OidcProviderUnavailableError: If the provider reports it is not available
    """
    candidate = get_oidc_provider(provider_name)
    if candidate.is_available:
        return candidate

    raise OidcProviderUnavailableError(
        "Provider '%s' is not available in the current environment." % provider_name
    )
314
+
315
+
316
def get_oidc_provider_class(provider_name: str) -> Type[OidcTokenProvider]:
    """
    Look up the class of an OIDC provider by name, without instantiating it.

    Args:
        provider_name: Registered name of the provider

    Returns:
        The provider class registered under ``provider_name``

    Raises:
        OidcProviderNotFoundError: If no provider is registered under that name
    """
    cls = _registry.get_provider_class(provider_name)
    if cls:
        return cls

    known = ", ".join(_registry.provider_names)
    raise OidcProviderNotFoundError(
        "Unknown provider '%s'. Available providers: %s" % (provider_name, known)
    )
342
+
343
+
344
def auto_detect_oidc_provider() -> OidcTokenProvider:
    """
    Pick the single OIDC provider that is available in the current environment.

    Returns:
        The only provider whose ``is_available`` is True

    Raises:
        OidcProviderAutoDetectionError: If zero providers, or more than one,
            are available (also when no providers are registered at all)
    """
    active = [candidate for candidate in _registry.all_providers if candidate.is_available]
    active_names = [candidate.provider_name for candidate in active]
    registered = _registry.provider_names

    if len(active) == 1:
        # Happy path - exactly one provider detected
        logger.info("Found 1 available provider: %s", active_names[0])
        return active[0]

    if active:
        # More than one provider matched; the user has to choose
        msg = (
            "Multiple OIDC providers detected: %s. "
            "Please specify which provider to use with --type <provider>."
        ) % ", ".join(active_names)
    elif registered:
        # Nothing detected, but there are providers the user could name explicitly
        msg = (
            "No OIDC provider detected in current environment. "
            "Available providers: %s. "
            "Use --type <provider> to specify a provider explicitly."
        ) % ", ".join(registered)
    else:
        msg = "No OIDC providers are registered."

    logger.info(msg)
    raise OidcProviderAutoDetectionError(msg)
@@ -21,3 +21,13 @@ PARAM_APPLICATION_NAME: Literal["snowcli"] = "snowcli"
21
21
  # This is also defined on server side. Changing this parameter would require
22
22
  # a change in https://github.com/snowflakedb/snowflake
23
23
  INTERNAL_APPLICATION_NAME: Literal["SNOWFLAKE_CLI"] = "SNOWFLAKE_CLI"
24
+
25
# Authenticator identifiers (values of the "authenticator" connection parameter).
# Note that the MFA value is lowercase, unlike the others.
AUTHENTICATOR_WORKLOAD_IDENTITY: Literal["WORKLOAD_IDENTITY"] = "WORKLOAD_IDENTITY"
AUTHENTICATOR_SNOWFLAKE_JWT: Literal["SNOWFLAKE_JWT"] = "SNOWFLAKE_JWT"
AUTHENTICATOR_USERNAME_PASSWORD_MFA: Literal["username_password_mfa"] = (
    "username_password_mfa"
)
AUTHENTICATOR_OAUTH_AUTHORIZATION_CODE: Literal["OAUTH_AUTHORIZATION_CODE"] = (
    "OAUTH_AUTHORIZATION_CODE"
)
@@ -22,11 +22,16 @@ from typing import Dict, Optional
22
22
  import snowflake.connector
23
23
  from click.exceptions import ClickException
24
24
  from snowflake.cli import __about__
25
+ from snowflake.cli._app.auth.oidc_providers import OidcProviderTypeWithAuto
25
26
  from snowflake.cli._app.constants import (
27
+ AUTHENTICATOR_WORKLOAD_IDENTITY,
26
28
  INTERNAL_APPLICATION_NAME,
27
29
  PARAM_APPLICATION_NAME,
28
30
  )
29
31
  from snowflake.cli._app.telemetry import command_info
32
+ from snowflake.cli._plugins.auth.oidc.manager import (
33
+ OidcManager,
34
+ )
30
35
  from snowflake.cli.api.config import (
31
36
  get_connection_dict,
32
37
  get_env_value,
@@ -40,6 +45,7 @@ from snowflake.cli.api.feature_flags import FeatureFlag
40
45
  from snowflake.cli.api.secret import SecretType
41
46
  from snowflake.cli.api.secure_path import SecurePath
42
47
  from snowflake.connector import SnowflakeConnection
48
+ from snowflake.connector.auth.workload_identity import ApiFederatedAuthenticationType
43
49
  from snowflake.connector.errors import DatabaseError, ForbiddenError
44
50
 
45
51
  log = logging.getLogger(__name__)
@@ -54,6 +60,7 @@ SUPPORTED_ENV_OVERRIDES = [
54
60
  "user",
55
61
  "password",
56
62
  "authenticator",
63
+ "workload_identity_provider",
57
64
  "private_key_file",
58
65
  "private_key_path",
59
66
  "private_key_raw",
@@ -153,6 +160,14 @@ def connect_to_snowflake(
153
160
  if connection_parameters.get("authenticator") == "username_password_mfa":
154
161
  connection_parameters["client_request_mfa_token"] = True
155
162
 
163
+ # Handle WORKLOAD_IDENTITY authenticator (OIDC authentication)
164
+ if (
165
+ connection_parameters.get("authenticator") == AUTHENTICATOR_WORKLOAD_IDENTITY
166
+ and connection_parameters.get("workload_identity_provider")
167
+ == ApiFederatedAuthenticationType.OIDC.value
168
+ ):
169
+ _maybe_update_oidc_token(connection_parameters)
170
+
156
171
  if enable_diag:
157
172
  connection_parameters["enable_connection_diag"] = enable_diag
158
173
  if diag_log_path:
@@ -335,3 +350,23 @@ def prepare_private_key(
335
350
  encryption_algorithm=NoEncryption(),
336
351
  )
337
352
  )
353
+
354
+
355
def _maybe_update_oidc_token(connection_parameters: dict) -> dict:
    """
    Best-effort: auto-detect an OIDC token and put it into the connection parameters.

    When auto-detection returns a non-empty token, the dictionary is updated
    in place under the "token" key. The same dictionary is returned for
    convenience. Any failure is logged at INFO level and swallowed, so the
    connection attempt proceeds with the parameters it already had.

    NOTE(review): a "token" that is already present is overwritten whenever
    auto-detection succeeds. Confirm this precedence is intended.
    """
    try:
        token = OidcManager().read_token(OidcProviderTypeWithAuto.AUTO)
    except Exception as e:
        # Deliberately broad: automatic token acquisition is optional and must
        # never break the connection flow.
        log.info(
            "No token found during %s auto-detection: %s",
            AUTHENTICATOR_WORKLOAD_IDENTITY,
            str(e),
        )
        return connection_parameters

    if token:
        log.info("%s token acquired automatically", AUTHENTICATOR_WORKLOAD_IDENTITY)
        connection_parameters["token"] = token
    return connection_parameters
@@ -1,11 +1,13 @@
1
1
  from snowflake.cli._plugins.auth.keypair.commands import app as keypair_app
2
+ from snowflake.cli._plugins.auth.oidc.commands import (
3
+ app as oidc_app,
4
+ )
2
5
  from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
3
- from snowflake.cli.api.feature_flags import FeatureFlag
4
6
 
5
7
app = SnowTyperFactory(
    help="Manages authentication methods.",
    name="auth",
)

# Attach the sub-command groups; each one is defined in its own module.
app.add_typer(keypair_app)
app.add_typer(oidc_app)
@@ -4,6 +4,7 @@ import typer
4
4
  from snowflake.cli._plugins.auth.keypair.manager import AuthManager, PublicKeyProperty
5
5
  from snowflake.cli.api.commands.flags import SecretTypeParser
6
6
  from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
7
+ from snowflake.cli.api.feature_flags import FeatureFlag
7
8
  from snowflake.cli.api.output.types import (
8
9
  CollectionResult,
9
10
  CommandResult,
@@ -16,6 +17,7 @@ from snowflake.cli.api.secure_path import SecurePath
16
17
app = SnowTyperFactory(
    name="keypair",
    # The group is hidden whenever the ENABLE_AUTH_KEYPAIR feature flag is disabled.
    is_hidden=lambda: FeatureFlag.ENABLE_AUTH_KEYPAIR.is_disabled(),
    help="Manages authentication.",
)
20
22
 
21
23
 
@@ -0,0 +1,13 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,47 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import typer
16
+ from snowflake.cli._app.auth.oidc_providers import (
17
+ OidcProviderTypeWithAuto,
18
+ )
19
+ from snowflake.cli._plugins.auth.oidc.manager import OidcManager
20
+ from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
21
+ from snowflake.cli.api.output.types import MessageResult
22
+
23
app = SnowTyperFactory(
    name="oidc",
    help="Manages OIDC authentication.",
)


# Shared `--type` option. It defaults to "auto", which requests auto-detection
# of the single available provider. A plain string literal is used for the
# help text: the former f-string had no placeholders (ruff F541).
AutoProviderTypeOption = typer.Option(
    OidcProviderTypeWithAuto.AUTO.value,
    "--type",
    help="Type of OIDC provider to use",
    show_default=False,
)
35
+
36
+
37
@app.command("read-token", requires_connection=False)
def read_token(
    _type: OidcProviderTypeWithAuto = AutoProviderTypeOption,
    **options,
):
    """
    Reads OIDC token based on the specified type.
    Use 'auto' to auto-detect available providers.
    """
    # The docstring above is presumably rendered as the command's help text,
    # so it is kept verbatim.
    manager = OidcManager()
    token = manager.read_token(provider_type=_type)
    # The raw token string becomes the command's output message.
    return MessageResult(token)
@@ -0,0 +1,66 @@
1
+ # Copyright (c) 2024 Snowflake Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ from typing import TypeAlias
17
+
18
+ from snowflake.cli._app.auth.errors import OidcProviderError
19
+ from snowflake.cli._app.auth.oidc_providers import (
20
+ OidcProviderType,
21
+ OidcProviderTypeWithAuto,
22
+ auto_detect_oidc_provider,
23
+ get_active_oidc_provider,
24
+ )
25
+ from snowflake.cli.api.exceptions import CliError
26
+
27
logger = logging.getLogger(__name__)


# Accepted provider selectors: a concrete type, or one of the AUTO-capable values.
Providers: TypeAlias = OidcProviderType | OidcProviderTypeWithAuto
31
+
32
+
33
class OidcManager:
    """
    Manages OIDC authentication.

    Resolves an OIDC provider, either explicitly or by auto-detection, reads
    a token from it, and converts provider errors into CliError.
    """

    def read_token(
        self,
        provider_type: Providers = OidcProviderTypeWithAuto.AUTO,
    ) -> str:
        """
        Reads OIDC token based on the specified provider type.

        Args:
            provider_type: Type of provider to read token from ("auto" for auto-detection)

        Returns:
            The OIDC token string returned by the selected provider.

        Raises:
            CliError: If the provider cannot be resolved or the token cannot be
                read. The originating OidcProviderError is chained as the cause.
        """
        logger.info("Reading OIDC token with provider type: %s", provider_type)

        try:
            if provider_type == OidcProviderTypeWithAuto.AUTO:
                provider = auto_detect_oidc_provider()
            else:
                # Explicit choice: fails if the provider is unknown or unavailable here.
                provider = get_active_oidc_provider(provider_type.value)
            return provider.get_token()
        except OidcProviderError as e:
            logger.error("OIDC provider error: %s", str(e))
            # Chain so the original provider error stays visible in tracebacks.
            raise CliError(str(e)) from e