dagster-snowflake 0.13.3rc0__py3-none-any.whl → 0.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,31 @@
1
+ import base64
1
2
  import sys
2
3
  import warnings
4
+ from collections.abc import Iterator, Mapping, Sequence
3
5
  from contextlib import closing, contextmanager
4
-
5
- from dagster import check, resource
6
-
7
- from .configs import define_snowflake_config
6
+ from datetime import datetime
7
+ from typing import Any, Optional, Union
8
+
9
+ import dagster._check as check
10
+ from cryptography.hazmat.backends import default_backend
11
+ from cryptography.hazmat.primitives import serialization
12
+ from dagster import (
13
+ ConfigurableResource,
14
+ IAttachDifferentObjectToOpContext,
15
+ get_dagster_logger,
16
+ resource,
17
+ )
18
+ from dagster._annotations import public
19
+ from dagster._core.definitions.resource_definition import dagster_maintained_resource
20
+ from dagster._core.storage.event_log.sql_event_log import SqlDbConnection
21
+ from dagster._utils.cached_method import cached_method
22
+ from dagster.components.lib.sql_component.sql_client import SQLClient
23
+ from pydantic import Field, model_validator, validator
24
+
25
+ from dagster_snowflake.constants import (
26
+ SNOWFLAKE_PARTNER_CONNECTION_IDENTIFIER,
27
+ SNOWFLAKE_PARTNER_CONNECTION_IDENTIFIER_SQLALCHEMY,
28
+ )
8
29
 
9
30
  try:
10
31
  import snowflake.connector
@@ -20,173 +41,843 @@ except ImportError:
20
41
  raise
21
42
 
22
43
 
23
class SnowflakeResource(ConfigurableResource, IAttachDifferentObjectToOpContext, SQLClient):
    """A resource for connecting to the Snowflake data warehouse.

    If connector configuration is not set, SnowflakeResource.get_connection() will return a
    `snowflake.connector.Connection <https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api#object-connection>`__
    object. If connector="sqlalchemy" configuration is set, then SnowflakeResource.get_connection() will
    return a `SQLAlchemy Connection <https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.Connection>`__
    or a `SQLAlchemy raw connection <https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.Engine.raw_connection>`__.
    If connector="adbc" configuration is set, SnowflakeResource.get_connection() will return the
    DB-API connection created by ``adbc_driver_snowflake.dbapi.connect``.

    A simple example of querying Snowflake is shown below:

    Examples:
        .. code-block:: python

            from dagster import EnvVar, job, op
            from dagster_snowflake import SnowflakeResource

            @op
            def get_one(snowflake_resource: SnowflakeResource):
                with snowflake_resource.get_connection() as conn:
                    # conn is a snowflake.connector.Connection object
                    conn.cursor().execute("SELECT 1")

            @job
            def my_snowflake_job():
                get_one()

            my_snowflake_job.execute_in_process(
                resources={
                    'snowflake_resource': SnowflakeResource(
                        account=EnvVar("SNOWFLAKE_ACCOUNT"),
                        user=EnvVar("SNOWFLAKE_USER"),
                        password=EnvVar("SNOWFLAKE_PASSWORD"),
                        database="MY_DATABASE",
                        schema="MY_SCHEMA",
                        warehouse="MY_WAREHOUSE"
                    )
                }
            )
    """

    account: Optional[str] = Field(
        default=None,
        description=(
            "Your Snowflake account name. For more details, see the `Snowflake documentation"
            " <https://docs.snowflake.com/developer-guide/python-connector/python-connector-api>`__."
        ),
    )

    user: str = Field(description="User login name.")

    password: Optional[str] = Field(default=None, description="User password.")

    database: Optional[str] = Field(
        default=None,
        description=(
            "Name of the default database to use. After login, you can use ``USE DATABASE`` "
            "to change the database."
        ),
    )

    # ``schema`` is a reserved word for pydantic, so the attribute is ``schema_`` with an alias.
    schema_: Optional[str] = Field(
        default=None,
        description=(
            "Name of the default schema to use. After login, you can use ``USE SCHEMA`` to "
            "change the schema."
        ),
        alias="schema",
    )

    role: Optional[str] = Field(
        default=None,
        description=(
            "Name of the default role to use. After login, you can use ``USE ROLE`` to change "
            "the role."
        ),
    )

    warehouse: Optional[str] = Field(
        default=None,
        description=(
            "Name of the default warehouse to use. After login, you can use ``USE WAREHOUSE`` "
            "to change the warehouse."
        ),
    )

    private_key: Optional[str] = Field(
        default=None,
        description=(
            "Raw private key to use. See the `Snowflake documentation"
            " <https://docs.snowflake.com/en/user-guide/key-pair-auth.html>`__ for details."
            " Alternately, set private_key_path and private_key_password. To avoid issues with"
            " newlines in the keys, you can optionally base64 encode the key. You can retrieve"
            " the base64 encoded key with this shell command: ``cat rsa_key.p8 | base64``"
        ),
    )

    private_key_password: Optional[str] = Field(
        default=None,
        description=(
            "Raw private key password to use. See the `Snowflake documentation"
            " <https://docs.snowflake.com/en/user-guide/key-pair-auth.html>`__ for details."
            " Required for both ``private_key`` and ``private_key_path`` if the private key is"
            " encrypted. For unencrypted keys, this config can be omitted or set to None."
        ),
    )

    private_key_path: Optional[str] = Field(
        default=None,
        description=(
            "Raw private key path to use. See the `Snowflake documentation"
            " <https://docs.snowflake.com/en/user-guide/key-pair-auth.html>`__ for details."
            " Alternately, set the raw private key as ``private_key``."
        ),
    )

    autocommit: Optional[bool] = Field(
        default=None,
        description=(
            "None by default, which honors the Snowflake parameter AUTOCOMMIT. Set to True "
            "or False to enable or disable autocommit mode in the session, respectively."
        ),
    )

    client_prefetch_threads: Optional[int] = Field(
        default=None,
        description=(
            "Number of threads used to download the results sets (4 by default). "
            "Increasing the value improves fetch performance but requires more memory."
        ),
    )

    client_session_keep_alive: Optional[bool] = Field(
        default=None,
        description=(
            "False by default. Set this to True to keep the session active indefinitely, "
            "even if there is no activity from the user. Make certain to call the close method to "
            "terminate the thread properly or the process may hang."
        ),
    )

    login_timeout: Optional[int] = Field(
        default=None,
        description=(
            "Timeout in seconds for login. By default, 60 seconds. The login request gives "
            'up after the timeout length if the HTTP response is "success".'
        ),
    )

    network_timeout: Optional[int] = Field(
        default=None,
        description=(
            "Timeout in seconds for all other operations. By default, none/infinite. A general"
            " request gives up after the timeout length if the HTTP response is not 'success'."
        ),
    )

    ocsp_response_cache_filename: Optional[str] = Field(
        default=None,
        description=(
            "URI for the OCSP response cache file. By default, the OCSP response cache "
            "file is created in the cache directory."
        ),
    )

    validate_default_parameters: Optional[bool] = Field(
        default=None,
        description=(
            "If True, raise an exception if the warehouse, database, or schema doesn't exist."
            " Defaults to False."
        ),
    )

    paramstyle: Optional[str] = Field(
        default=None,
        description=(
            "pyformat by default for client side binding. Specify qmark or numeric to "
            "change bind variable formats for server side binding."
        ),
    )

    timezone: Optional[str] = Field(
        default=None,
        description=(
            "None by default, which honors the Snowflake parameter TIMEZONE. Set to a "
            "valid time zone (e.g. America/Los_Angeles) to set the session time zone."
        ),
    )

    connector: Optional[str] = Field(
        default=None,
        description=(
            "Indicate alternative database connection engine. Permissible options are "
            "'sqlalchemy' and 'adbc'; otherwise defaults to use the Snowflake Connector for Python."
        ),
        is_required=False,  # type: ignore
    )

    cache_column_metadata: Optional[str] = Field(
        default=None,
        description=(
            "Optional parameter when connector is set to sqlalchemy. Snowflake SQLAlchemy takes a"
            " flag ``cache_column_metadata=True`` such that all of column metadata for all tables"
            ' are "cached".'
        ),
    )

    numpy: Optional[bool] = Field(
        default=None,
        description=(
            "Optional parameter when connector is set to sqlalchemy. To enable fetching "
            "NumPy data types, add numpy=True to the connection parameters."
        ),
    )

    authenticator: Optional[str] = Field(
        default=None,
        description="Optional parameter to specify the authentication mechanism to use.",
    )
    additional_snowflake_connection_args: Optional[dict[str, Any]] = Field(
        default=None,
        description=(
            "Additional keyword arguments to pass to the snowflake.connector.connect function. For a full list of"
            " available arguments, see the `Snowflake documentation"
            " <https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect>`__."
            " This config will be ignored if using the sqlalchemy connector. When using the adbc"
            " connector, only keys beginning with ``adbc.snowflake.`` are passed through."
        ),
    )

    @validator("paramstyle")
    def validate_paramstyle(cls, v: Optional[str]) -> Optional[str]:
        """Reject any paramstyle other than pyformat, qmark, or numeric (None is allowed)."""
        valid_config = ["pyformat", "qmark", "numeric"]
        if v is not None and v not in valid_config:
            raise ValueError(
                "Snowflake Resource: 'paramstyle' configuration value must be one of:"
                f" {','.join(valid_config)}."
            )
        return v

    @validator("connector")
    def validate_connector(cls, v: Optional[str]) -> Optional[str]:
        """Reject any connector other than None, "sqlalchemy", or "adbc"."""
        if v is not None and v not in ["sqlalchemy", "adbc"]:
            raise ValueError(
                "Snowflake Resource: 'connector' configuration value must be None, sqlalchemy or adbc."
            )
        return v

    @model_validator(mode="before")
    def validate_authentication(cls, values):
        """Ensure a usable, unambiguous authentication configuration.

        Counts how many of ``password``, ``private_key`` and ``private_key_path`` are set.
        At least one of them must be set unless ``authenticator`` is given, and at most one
        may be set in any case.

        Raises:
            dagster CheckError (via ``check.invariant``) when either rule is violated.
        """
        auths_set = 0
        auths_set += 1 if values.get("password") is not None else 0
        auths_set += 1 if values.get("private_key") is not None else 0
        auths_set += 1 if values.get("private_key_path") is not None else 0

        # if authenticator is set, there can be 0 or 1 additional auth method;
        # otherwise, ensure at least 1 method is provided
        check.invariant(
            auths_set > 0 or values.get("authenticator") is not None,
            "Missing config: Password, private key, or authenticator authentication required"
            " for Snowflake resource.",
        )

        # ensure that only 1 non-authenticator method is provided
        check.invariant(
            auths_set <= 1,
            "Incorrect config: Cannot provide both password and private key authentication to"
            " Snowflake Resource.",
        )

        return values

    @classmethod
    def _is_dagster_maintained(cls) -> bool:
        return True

    @property
    @cached_method
    def _connection_args(self) -> Mapping[str, Any]:
        """Keyword arguments for ``snowflake.connector.connect``.

        Only config values that are not None are included, so the connector's own defaults
        apply to anything left unset. A configured private key (inline or by path) is loaded
        and passed as DER bytes. ``additional_snowflake_connection_args`` is applied last and
        can therefore override any computed value.
        """
        conn_args = {
            k: self._resolved_config_dict.get(k)
            for k in (
                "account",
                "user",
                "password",
                "database",
                "schema",
                "role",
                "warehouse",
                "autocommit",
                "client_prefetch_threads",
                "client_session_keep_alive",
                "login_timeout",
                "network_timeout",
                "ocsp_response_cache_filename",
                "validate_default_parameters",
                "paramstyle",
                "timezone",
                "authenticator",
            )
            if self._resolved_config_dict.get(k) is not None
        }
        if (
            self._resolved_config_dict.get("private_key", None) is not None
            or self._resolved_config_dict.get("private_key_path", None) is not None
        ):
            conn_args["private_key"] = self._snowflake_private_key(self._resolved_config_dict)

        conn_args["application"] = SNOWFLAKE_PARTNER_CONNECTION_IDENTIFIER

        if self._resolved_config_dict.get("additional_snowflake_connection_args") is not None:
            conn_args.update(self._resolved_config_dict["additional_snowflake_connection_args"])
        return conn_args

    @property
    @cached_method
    def _sqlalchemy_connection_args(self) -> Mapping[str, Any]:
        """Keyword arguments for ``snowflake.sqlalchemy.URL`` (non-None config values only)."""
        conn_args: dict[str, Any] = {
            k: self._resolved_config_dict.get(k)
            for k in (
                "account",
                "user",
                "password",
                "database",
                "schema",
                "role",
                "warehouse",
                "cache_column_metadata",
                "numpy",
            )
            if self._resolved_config_dict.get(k) is not None
        }
        conn_args["application"] = SNOWFLAKE_PARTNER_CONNECTION_IDENTIFIER_SQLALCHEMY

        return conn_args

    @property
    @cached_method
    def _sqlalchemy_engine_args(self) -> Mapping[str, Any]:
        """``connect_args`` for ``sqlalchemy.create_engine``: private key and authenticator."""
        config = self._resolved_config_dict
        sqlalchemy_engine_args = {}
        if (
            config.get("private_key", None) is not None
            or config.get("private_key_path", None) is not None
        ):
            # sqlalchemy passes private key args separately, so store them in a new dict
            sqlalchemy_engine_args["private_key"] = self._snowflake_private_key(config)
        if config.get("authenticator", None) is not None:
            sqlalchemy_engine_args["authenticator"] = config["authenticator"]

        return sqlalchemy_engine_args

    @property
    @cached_method
    def _adbc_connection_args(self) -> Mapping[str, Any]:
        """Translate the resource config into ``db_kwargs`` for the ADBC Snowflake driver.

        Falsy config values (None, empty string, 0) are omitted. Only keys of
        ``additional_snowflake_connection_args`` that start with ``adbc.snowflake.`` are
        passed through.
        """
        config = self._resolved_config_dict
        adbc_engine_args = {}

        if config.get("account"):
            adbc_engine_args["adbc.snowflake.sql.account"] = config["account"]
        if config.get("user"):
            adbc_engine_args["username"] = config["user"]
        if config.get("password"):
            adbc_engine_args["password"] = config["password"]
        if config.get("database"):
            adbc_engine_args["adbc.snowflake.sql.db"] = config["database"]
        if config.get("schema"):
            adbc_engine_args["adbc.snowflake.sql.schema"] = config["schema"]
        if config.get("role"):
            adbc_engine_args["adbc.snowflake.sql.role"] = config["role"]
        if config.get("warehouse"):
            adbc_engine_args["adbc.snowflake.sql.warehouse"] = config["warehouse"]

        if config.get("authenticator"):
            # Map snowflake-connector authenticator names to ADBC auth types; unknown
            # values are passed through unchanged.
            auth_mapping = {
                "snowflake": "auth_snowflake",
                "oauth": "auth_oauth",
                "externalbrowser": "auth_ext_browser",
                "okta": "auth_okta",
                "jwt": "auth_jwt",
                "snowflake_jwt": "auth_jwt",
            }
            auth_type = auth_mapping.get(config["authenticator"].lower(), config["authenticator"])
            adbc_engine_args["adbc.snowflake.sql.auth_type"] = auth_type

        if config.get("private_key") or config.get("private_key_path"):
            # Key-pair auth forces JWT, overriding any auth type derived from ``authenticator``.
            # NOTE(review): the raw ``private_key`` config string is passed as-is (it is not
            # base64-decoded the way _snowflake_private_key does) — confirm the format ADBC
            # expects for jwt_private_key_pkcs8_value.
            adbc_engine_args["adbc.snowflake.sql.auth_type"] = "auth_jwt"
            if config.get("private_key"):
                adbc_engine_args["adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_value"] = (
                    config["private_key"]
                )
            elif config.get("private_key_path"):
                adbc_engine_args["adbc.snowflake.sql.client_option.jwt_private_key"] = config[
                    "private_key_path"
                ]

            if config.get("private_key_password"):
                adbc_engine_args[
                    "adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_password"
                ] = config["private_key_password"]

        # ADBC takes durations as strings with a unit suffix.
        if config.get("login_timeout"):
            adbc_engine_args["adbc.snowflake.sql.client_option.login_timeout"] = (
                f"{config['login_timeout']}s"
            )
        if config.get("network_timeout"):
            adbc_engine_args["adbc.snowflake.sql.client_option.request_timeout"] = (
                f"{config['network_timeout']}s"
            )
        if config.get("client_session_keep_alive") is not None:
            adbc_engine_args["adbc.snowflake.sql.client_option.keep_session_alive"] = str(
                config["client_session_keep_alive"]
            ).lower()

        adbc_engine_args["adbc.snowflake.sql.client_option.app_name"] = (
            SNOWFLAKE_PARTNER_CONNECTION_IDENTIFIER
        )

        extra_args = config.get("additional_snowflake_connection_args")
        if extra_args:
            # Allow direct ADBC option names to be passed through
            adbc_engine_args.update(
                {key: value for key, value in extra_args.items() if key.startswith("adbc.snowflake.")}
            )

        return adbc_engine_args

    def _snowflake_private_key(self, config) -> bytes:
        """Load the configured private key and return it as unencrypted PKCS#8 DER bytes.

        The key is read from ``private_key_path`` when set, otherwise taken from the inline
        ``private_key`` string. If the PEM fails to parse, it is retried after base64-decoding,
        which supports base64-encoded keys.

        Raises:
            ValueError: if the key cannot be parsed either as PEM or as base64-encoded PEM.
        """
        # If the user has defined a path to a private key, we will use that.
        if config.get("private_key_path", None) is not None:
            with open(config.get("private_key_path"), "rb") as key:
                private_key = key.read()
        else:
            private_key = config.get("private_key", None).encode()

        kwargs = {}
        if config.get("private_key_password", None) is not None:
            kwargs["password"] = config["private_key_password"].encode()
        else:
            kwargs["password"] = None

        try:
            p_key = serialization.load_pem_private_key(
                private_key, backend=default_backend(), **kwargs
            )

        # key fails to load, possibly indicating key is base64 encoded
        except ValueError:
            try:
                # binascii.Error from b64decode is a ValueError subclass, so it is caught below.
                private_key = base64.b64decode(private_key)
                p_key = serialization.load_pem_private_key(
                    private_key, backend=default_backend(), **kwargs
                )
            except ValueError as err:
                raise ValueError(
                    "Unable to load private key. You may need to base64 encode your private key."
                    " You can retrieve the base64 encoded key with this shell command: cat"
                    " rsa_key.p8 | base64"
                ) from err

        pkb = p_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

        return pkb

    @public
    @contextmanager
    def get_connection(
        self, raw_conn: bool = True
    ) -> Iterator[Union[SqlDbConnection, snowflake.connector.SnowflakeConnection]]:
        """Gets a connection to Snowflake as a context manager.

        If connector configuration is not set, SnowflakeResource.get_connection() will return a
        `snowflake.connector.Connection <https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api#object-connection>`__
        object. If connector="sqlalchemy" configuration is set, then SnowflakeResource.get_connection() will
        return a `SQLAlchemy Connection <https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.Connection>`__
        or a `SQLAlchemy raw connection <https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.Engine.raw_connection>`__
        if raw_conn=True. If connector="adbc" configuration is set, the DB-API connection from
        ``adbc_driver_snowflake.dbapi.connect`` is returned.

        The connection is always closed when the ``with`` block exits, including when it
        raises. With the default Snowflake connector and autocommit not enabled, the connection
        is committed only if the block exits normally.

        Args:
            raw_conn (bool): If using the sqlalchemy connector, you can set raw_conn to True to create a raw
                connection. Defaults to True.

        Examples:
            .. code-block:: python

                @op
                def get_query_status(snowflake: SnowflakeResource, query_id):
                    with snowflake.get_connection() as conn:
                        # conn is a Snowflake Connection object or a SQLAlchemy Connection if
                        # sqlalchemy is specified as the connector in the Snowflake Resource config

                        return conn.get_query_status(query_id)

        """
        if self.connector == "sqlalchemy":
            from snowflake.sqlalchemy import URL
            from sqlalchemy import create_engine

            engine = create_engine(
                URL(**self._sqlalchemy_connection_args), connect_args=self._sqlalchemy_engine_args
            )
            try:
                conn = engine.raw_connection() if raw_conn else engine.connect()
                try:
                    yield conn
                finally:
                    conn.close()
            finally:
                engine.dispose()
        elif self.connector == "adbc":
            import adbc_driver_snowflake.dbapi

            conn = adbc_driver_snowflake.dbapi.connect(
                db_kwargs=self._adbc_connection_args,  # pyright: ignore[reportArgumentType]
            )
            try:
                yield conn
            finally:
                conn.close()
        else:
            conn = snowflake.connector.connect(**self._connection_args)
            try:
                yield conn
                # Only commit work from a block that completed without raising.
                if not self.autocommit:
                    conn.commit()
            finally:
                conn.close()

    def get_object_to_set_on_execution_context(self) -> Any:
        """Return a SnowflakeConnection wrapping this resource for the op execution context."""
        # Directly create a SnowflakeConnection here for backcompat since the SnowflakeConnection
        # has methods this resource does not have
        return SnowflakeConnection(
            config=self._resolved_config_dict,
            log=get_dagster_logger(),
            snowflake_connection_resource=self,
        )

    def connect_and_execute(self, sql: str) -> None:
        """Open a connection via get_connection() and execute ``sql`` on a new cursor."""
        with self.get_connection() as conn:
            conn.cursor().execute(sql)
584
+
585
+
586
class SnowflakeConnection:
    """A connection to Snowflake that can execute queries. In general this class should not be
    directly instantiated, but rather used as a resource in an op or asset via the
    :py:func:`snowflake_resource`.

    ``SnowflakeResource.get_object_to_set_on_execution_context`` also returns an instance of this
    class, for backwards compatibility with code written against this API. Every connection is
    opened through the wrapped resource's ``get_connection``.
    """

    def __init__(
        self, config: Mapping[str, str], log, snowflake_connection_resource: SnowflakeResource
    ):
        """
        Args:
            config: Resolved resource config. Accepted for interface compatibility; it is not
                read here — connection settings come from ``snowflake_connection_resource``.
            log: Logger used to record each executed query.
            snowflake_connection_resource: The resource that opens the underlying connections.
        """
        self.snowflake_connection_resource = snowflake_connection_resource
        self.log = log

    @public
    @contextmanager
    def get_connection(
        self, raw_conn: bool = True
    ) -> Iterator[Union[SqlDbConnection, snowflake.connector.SnowflakeConnection]]:
        """Gets a connection to Snowflake as a context manager.

        If using the execute_query, execute_queries, or load_table_from_local_parquet methods,
        you do not need to create a connection using this context manager.

        Args:
            raw_conn (bool): If using the sqlalchemy connector, you can set raw_conn to True to create a raw
                connection. Defaults to True.

        Examples:
            .. code-block:: python

                @op(
                    required_resource_keys={"snowflake"}
                )
                def get_query_status(context, query_id):
                    with context.resources.snowflake.get_connection() as conn:
                        # conn is a Snowflake Connection object or a SQLAlchemy Connection if
                        # sqlalchemy is specified as the connector in the Snowflake Resource config

                        return conn.get_query_status(query_id)

        """
        with self.snowflake_connection_resource.get_connection(raw_conn=raw_conn) as conn:
            yield conn

    @public
    def execute_query(
        self,
        sql: str,
        parameters: Optional[Union[Sequence[Any], Mapping[Any, Any]]] = None,
        fetch_results: bool = False,
        use_pandas_result: bool = False,
    ):
        """Execute a query in Snowflake.

        Args:
            sql (str): the query to be executed
            parameters (Optional[Union[Sequence[Any], Mapping[Any, Any]]]): Parameters to be passed to the query. See the
                `Snowflake documentation <https://docs.snowflake.com/en/user-guide/python-connector-example.html#binding-data>`__
                for more information. Must be a list or dict.
            fetch_results (bool): If True, will return the result of the query. Defaults to False. If True
                and use_pandas_result is also True, results will be returned as a Pandas DataFrame.
            use_pandas_result (bool): If True, will return the result of the query as a Pandas DataFrame.
                Defaults to False. If fetch_results is False and use_pandas_result is True, an error will be
                raised.

        Returns:
            The result of the query if fetch_results or use_pandas_result is True, otherwise returns None

        Examples:
            .. code-block:: python

                @op(required_resource_keys={"snowflake"})
                def drop_database(context):
                    context.resources.snowflake.execute_query(
                        "DROP DATABASE IF EXISTS MY_DATABASE"
                    )
        """
        check.str_param(sql, "sql")
        check.opt_inst_param(parameters, "parameters", (list, dict))
        check.bool_param(fetch_results, "fetch_results")
        if not fetch_results and use_pandas_result:
            check.failed("If use_pandas_result is True, fetch_results must also be True.")

        # Mappings are copied into a plain dict before binding.
        bound_parameters = dict(parameters) if isinstance(parameters, Mapping) else parameters

        with self.get_connection() as conn:
            with closing(conn.cursor()) as cursor:
                self.log.info("Executing query: " + sql)
                cursor.execute(sql, bound_parameters)
                if use_pandas_result:
                    return cursor.fetch_pandas_all()
                if fetch_results:
                    return cursor.fetchall()

    @public
    def execute_queries(
        self,
        sql_queries: Sequence[str],
        parameters: Optional[Union[Sequence[Any], Mapping[Any, Any]]] = None,
        fetch_results: bool = False,
        use_pandas_result: bool = False,
    ) -> Optional[Sequence[Any]]:
        """Execute multiple queries in Snowflake.

        All queries run in series on a single connection and cursor.

        Args:
            sql_queries (Sequence[str]): List of queries to be executed in series
            parameters (Optional[Union[Sequence[Any], Mapping[Any, Any]]]): Parameters to be passed to every query. See the
                `Snowflake documentation <https://docs.snowflake.com/en/user-guide/python-connector-example.html#binding-data>`__
                for more information. Must be a list or dict.
            fetch_results (bool): If True, will return the results of the queries as a list. Defaults to False. If True
                and use_pandas_result is also True, results will be returned as Pandas DataFrames.
            use_pandas_result (bool): If True, will return the results of the queries as a list of a Pandas DataFrames.
                Defaults to False. If fetch_results is False and use_pandas_result is True, an error will be
                raised.

        Returns:
            The results of the queries as a list (one entry per query, in order) if fetch_results
            is True, otherwise returns None

        Examples:
            .. code-block:: python

                @op(required_resource_keys={"snowflake"})
                def create_fresh_database(context):
                    queries = ["DROP DATABASE IF EXISTS MY_DATABASE", "CREATE DATABASE MY_DATABASE"]
                    context.resources.snowflake.execute_queries(
                        sql_queries=queries
                    )

        """
        check.sequence_param(sql_queries, "sql_queries", of_type=str)
        check.opt_inst_param(parameters, "parameters", (list, dict))
        check.bool_param(fetch_results, "fetch_results")
        if not fetch_results and use_pandas_result:
            check.failed("If use_pandas_result is True, fetch_results must also be True.")

        # Normalize the bindings once; they are identical for every query.
        bound_parameters = dict(parameters) if isinstance(parameters, Mapping) else parameters

        results: list[Any] = []
        with self.get_connection() as conn:
            with closing(conn.cursor()) as cursor:
                for sql in sql_queries:
                    self.log.info("Executing query: " + sql)
                    cursor.execute(sql, bound_parameters)
                    # list.append returns None, so never rebind ``results`` to its result.
                    if use_pandas_result:
                        results.append(cursor.fetch_pandas_all())
                    elif fetch_results:
                        results.append(cursor.fetchall())

        # use_pandas_result implies fetch_results (checked above).
        return results if fetch_results else None

    @public
    def load_table_from_local_parquet(self, src: str, table: str):
        """Stores the content of a parquet file to a Snowflake table.

        Args:
            src (str): the name of the file to store in Snowflake
            table (str): the name of the table to store the data. If the table does not exist, it will
                be created. Otherwise the contents of the table will be replaced with the data in src

        Note:
            ``src`` and ``table`` are interpolated verbatim into the SQL statements (identifiers
            and stage paths cannot be bound as parameters), so they must come from trusted input.

        Examples:
            .. code-block:: python

                import pandas as pd
                import pyarrow as pa
                import pyarrow.parquet as pq

                @op(required_resource_keys={"snowflake"})
                def write_parquet_file(context):
                    df = pd.DataFrame({"one": [1, 2, 3], "ten": [11, 12, 13]})
                    table = pa.Table.from_pandas(df)
                    pq.write_table(table, "example.parquet")
                    context.resources.snowflake.load_table_from_local_parquet(
                        src="example.parquet",
                        table="MY_TABLE"
                    )

        """
        check.str_param(src, "src")
        check.str_param(table, "table")

        sql_queries = [
            f"CREATE OR REPLACE TABLE {table} ( data VARIANT DEFAULT NULL);",
            "CREATE OR REPLACE FILE FORMAT parquet_format TYPE = 'parquet';",
            f"PUT {src} @%{table};",
            f"COPY INTO {table} FROM @%{table} FILE_FORMAT = (FORMAT_NAME = 'parquet_format');",
        ]

        self.execute_queries(sql_queries)
141
776
 
142
777
 
778
+ @dagster_maintained_resource
143
779
  @resource(
144
- config_schema=define_snowflake_config(),
780
+ config_schema=SnowflakeResource.to_config_schema(),
145
781
  description="This resource is for connecting to the Snowflake data warehouse",
146
782
  )
147
- def snowflake_resource(context):
148
- """A resource for connecting to the Snowflake data warehouse.
783
+ def snowflake_resource(context) -> SnowflakeConnection:
784
+ """A resource for connecting to the Snowflake data warehouse. The returned resource object is an
785
+ instance of :py:class:`SnowflakeConnection`.
149
786
 
150
787
  A simple example of loading data into Snowflake and subsequently querying that data is shown below:
151
788
 
152
789
  Examples:
153
-
154
- .. code-block:: python
155
-
156
- from dagster import execute_pipeline, pipeline, DependencyDefinition, ModeDefinition
157
- from dagster_snowflake import snowflake_resource
158
-
159
- @op(required_resource_keys={'snowflake'})
160
- def get_one(context):
161
- context.resources.snowflake.execute_query('SELECT 1')
162
-
163
- @graph
164
- def my_snowflake_graph():
165
- get_one()
166
-
167
- my_snowflake_graph.to_job(
168
- resources={'snowflake': snowflake_resource}
169
- ).execute_in_process(
170
- run_config={
171
- 'resources': {
172
- 'snowflake': {
173
- 'config': {
174
- 'account': {'env': 'SNOWFLAKE_ACCOUNT'},
175
- 'user': {'env': 'SNOWFLAKE_USER'},
176
- 'password': {'env': 'SNOWFLAKE_PASSWORD'},
177
- 'database': {'env': 'SNOWFLAKE_DATABASE'},
178
- 'schema': {'env': 'SNOWFLAKE_SCHEMA'},
179
- 'warehouse': {'env': 'SNOWFLAKE_WAREHOUSE'},
790
+ .. code-block:: python
791
+
792
+ from dagster import job, op
793
+ from dagster_snowflake import snowflake_resource
794
+
795
+ @op(required_resource_keys={'snowflake'})
796
+ def get_one(context):
797
+ context.resources.snowflake.execute_query('SELECT 1')
798
+
799
+ @job(resource_defs={'snowflake': snowflake_resource})
800
+ def my_snowflake_job():
801
+ get_one()
802
+
803
+ my_snowflake_job.execute_in_process(
804
+ run_config={
805
+ 'resources': {
806
+ 'snowflake': {
807
+ 'config': {
808
+ 'account': {'env': 'SNOWFLAKE_ACCOUNT'},
809
+ 'user': {'env': 'SNOWFLAKE_USER'},
810
+ 'password': {'env': 'SNOWFLAKE_PASSWORD'},
811
+ 'database': {'env': 'SNOWFLAKE_DATABASE'},
812
+ 'schema': {'env': 'SNOWFLAKE_SCHEMA'},
813
+ 'warehouse': {'env': 'SNOWFLAKE_WAREHOUSE'},
814
+ }
180
815
  }
181
816
  }
182
817
  }
183
- }
184
- )
818
+ )
819
+ """
820
+ snowflake_resource = SnowflakeResource.from_resource_context(context)
821
+ return SnowflakeConnection(
822
+ config=context, log=context.log, snowflake_connection_resource=snowflake_resource
823
+ )
824
+
185
825
 
826
+ def fetch_last_updated_timestamps(
827
+ *,
828
+ snowflake_connection: Union[SqlDbConnection, snowflake.connector.SnowflakeConnection],
829
+ schema: str,
830
+ tables: Sequence[str],
831
+ database: Optional[str] = None,
832
+ ignore_missing_tables: Optional[bool] = False,
833
+ ) -> Mapping[str, datetime]:
834
+ """Fetch the last updated times of a list of tables in Snowflake.
835
+
836
+ If the underlying query to fetch the last updated time returns no results, a ValueError will be raised.
837
+
838
+ Args:
839
+ snowflake_connection (Union[SqlDbConnection, SnowflakeConnection]): A connection to Snowflake.
840
+ Accepts either a SnowflakeConnection or a sqlalchemy connection object,
841
+ which are the two types of connections emittable from the snowflake resource.
842
+ schema (str): The schema of the tables to fetch the last updated time for.
843
+ tables (Sequence[str]): A list of table names to fetch the last updated time for.
844
+ database (Optional[str]): The database of the table. Only required if the connection
845
+ has not been set with a database.
846
+ ignore_missing_tables (Optional[bool]): If True, tables not found in Snowflake
847
+ will be excluded from the result.
848
+
849
+ Returns:
850
+ Mapping[str, datetime]: A dictionary of table names to their last updated time in UTC.
186
851
  """
187
- return SnowflakeConnection(context)
852
+ check.invariant(len(tables) > 0, "Must provide at least one table name to query upon.")
853
+ # Table names in snowflake's information schema are stored in uppercase
854
+ uppercase_tables = [table.upper() for table in tables]
855
+ tables_str = ", ".join([f"'{table_name}'" for table_name in uppercase_tables])
856
+ fully_qualified_table_name = (
857
+ f"{database}.information_schema.tables" if database else "information_schema.tables"
858
+ )
188
859
 
860
+ query = f"""
861
+ SELECT table_name, CONVERT_TIMEZONE('UTC', last_altered) AS last_altered
862
+ FROM {fully_qualified_table_name}
863
+ WHERE table_schema = '{schema}' AND table_name IN ({tables_str});
864
+ """
865
+ result = snowflake_connection.cursor().execute(query)
866
+ if not result:
867
+ raise ValueError("No results returned from Snowflake update time query.")
868
+
869
+ result_mapping = {table_name: last_altered for table_name, last_altered in result}
870
+ result_correct_case = {}
871
+ for table_name in tables:
872
+ if table_name.upper() not in result_mapping:
873
+ if ignore_missing_tables:
874
+ continue
875
+ raise ValueError(f"Table {table_name} could not be found.")
876
+ last_altered = result_mapping[table_name.upper()]
877
+ check.invariant(
878
+ isinstance(last_altered, datetime),
879
+ "Expected last_altered to be a datetime, but it was not.",
880
+ )
881
+ result_correct_case[table_name] = last_altered
189
882
 
190
- def _filter_password(args):
191
- """Remove password from connection args for logging"""
192
- return {k: v for k, v in args.items() if k != "password"}
883
+ return result_correct_case