flyteplugins-snowflake 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,58 @@
1
+ """
2
+ Key features:
3
+
4
+ - Parameterized SQL queries with typed inputs
5
+ - Key-pair and password-based authentication
6
+ - Returns query results as DataFrames
7
+ - Automatic links to the Snowflake query dashboard in the Flyte UI
8
+ - Query cancellation on task abort
9
+
10
+ Basic usage example:
11
+ ```python
12
+ import flyte
13
+ from flyte.io import DataFrame
14
+ from flyteplugins.snowflake import Snowflake, SnowflakeConfig
15
+
16
+ config = SnowflakeConfig(
17
+ account="myorg-myaccount",
18
+ user="flyte_user",
19
+ database="ANALYTICS",
20
+ schema="PUBLIC",
21
+ warehouse="COMPUTE_WH",
22
+ )
23
+
24
+ count_users = Snowflake(
25
+ name="count_users",
26
+ query_template="SELECT COUNT(*) FROM users",
27
+ plugin_config=config,
28
+ output_dataframe_type=DataFrame,
29
+ )
30
+
31
+ flyte.TaskEnvironment.from_task("snowflake_env", count_users)
32
+
33
+ if __name__ == "__main__":
34
+ flyte.init_from_config()
35
+
36
+ # Run locally (connector runs in-process, requires credentials and packages locally)
37
+ run = flyte.with_runcontext(mode="local").run(count_users)
38
+
39
+ # Run remotely (connector runs on the control plane)
40
+ run = flyte.with_runcontext(mode="remote").run(count_users)
41
+
42
+ print(run.url)
43
+ ```
44
+ """
45
+
46
from flyte.io._dataframe.dataframe import DataFrameTransformerEngine

from flyteplugins.snowflake.connector import SnowflakeConnector
from flyteplugins.snowflake.dataframe import (
    PandasToSnowflakeEncodingHandlers,
    SnowflakeToPandasDecodingHandler,
)
from flyteplugins.snowflake.task import Snowflake, SnowflakeConfig

# Register the pandas <-> Snowflake dataframe handlers at import time so the
# engine can encode/decode structured datasets addressed by `snowflake://` URIs.
DataFrameTransformerEngine.register(PandasToSnowflakeEncodingHandlers())
DataFrameTransformerEngine.register(SnowflakeToPandasDecodingHandler())

# Public API of the plugin package.
__all__ = ["Snowflake", "SnowflakeConfig", "SnowflakeConnector"]
@@ -0,0 +1,404 @@
1
+ import asyncio
2
+ import re
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+
6
+ from async_lru import alru_cache
7
+ from flyte import logger
8
+ from flyte.connectors import AsyncConnector, ConnectorRegistry, Resource, ResourceMeta
9
+ from flyte.connectors.utils import convert_to_flyte_phase
10
+ from flyte.io import DataFrame
11
+ from flyteidl2.core.execution_pb2 import TaskExecution, TaskLog
12
+ from flyteidl2.core.tasks_pb2 import TaskTemplate
13
+ from google.protobuf import json_format
14
+
15
+ from snowflake.connector import SnowflakeConnection
16
+ from snowflake.connector import connect as snowflake_connect
17
+
18
+ TASK_TYPE = "snowflake"
19
+
20
+
21
@dataclass
class SnowflakeJobMetadata(ResourceMeta):
    """Metadata that identifies a submitted Snowflake query across connector calls.

    This object is produced by ``SnowflakeConnector.create`` and passed back to
    ``get``/``delete`` so they can reconnect to the same account and look the
    query up by its id.

    Attributes:
        account: Snowflake account identifier.
        user: Snowflake user name.
        database: Snowflake database name.
        schema: Snowflake schema name.
        warehouse: Snowflake warehouse name.
        query_id: Unique identifier (sfqid) for the submitted query.
        has_output: Indicates if the query produces output.
        connection_kwargs: Additional connection parameters.
    """

    account: str  # e.g. "myorg-myaccount" or a bare account locator
    user: str
    database: str
    schema: str
    warehouse: str
    query_id: str  # Snowflake query id (sfqid) assigned at submission
    has_output: bool  # True when the task interface declares output variables
    connection_kwargs: Optional[Dict[str, Any]] = None  # extra connect() params
45
+
46
+
47
def _get_private_key(private_key_content: str, private_key_passphrase: Optional[str] = None) -> bytes:
    """Convert a PEM-encoded private key into unencrypted PKCS#8 DER bytes.

    The Snowflake connector expects the ``private_key`` connection parameter in
    DER form, so the PEM text (typically injected as a secret) is re-serialized.

    Args:
        private_key_content: Private key text in PEM format.
        private_key_passphrase: Passphrase, if the PEM key is encrypted.

    Returns:
        The same key serialized as unencrypted PKCS#8 DER.
    """
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization

    pem_bytes = private_key_content.strip().encode()
    secret = private_key_passphrase.encode() if private_key_passphrase else None

    key = serialization.load_pem_private_key(pem_bytes, password=secret, backend=default_backend())

    return key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
75
+
76
+
77
@alru_cache
async def _get_snowflake_connection(
    account: str,
    user: str,
    database: str,
    schema: str,
    warehouse: str,
    private_key_content: Optional[str] = None,
    private_key_passphrase: Optional[str] = None,
    **connection_kwargs,
) -> SnowflakeConnection:
    """Open (and cache) a Snowflake connection.

    Key-pair authentication is used when ``private_key_content`` is provided;
    any other auth method (password, OAuth, ...) can be supplied through
    ``connection_kwargs``.
    See: https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api

    NOTE: ``alru_cache`` memoizes one connection per distinct argument set, so
    repeated calls with the same settings reuse a single live connection.

    Args:
        account: Snowflake account identifier.
        user: Snowflake user name.
        database: Snowflake database name.
        schema: Snowflake schema name.
        warehouse: Snowflake warehouse name.
        private_key_content: The private key content in PEM format, if any.
        private_key_passphrase: The passphrase for the private key, if any.
        **connection_kwargs: Additional connection parameters.

    Returns:
        An active Snowflake connection.
    """

    def _open() -> SnowflakeConnection:
        params = {
            "account": account,
            "user": user,
            "database": database,
            "schema": schema,
            "warehouse": warehouse,
        }
        params.update(connection_kwargs)

        # Key-pair auth wins over any `private_key` passed in connection_kwargs.
        if private_key_content:
            params["private_key"] = _get_private_key(private_key_content, private_key_passphrase)

        # The Snowflake connector itself validates that some auth method is present.
        return snowflake_connect(**params)

    # connect() performs blocking network I/O, so run it in the default executor.
    return await asyncio.get_running_loop().run_in_executor(None, _open)
128
+
129
+
130
+ def _construct_query_link(account: str, query_id: str) -> str:
131
+ """Construct a Snowflake console link for the query.
132
+
133
+ Args:
134
+ account: Snowflake account identifier.
135
+ query_id: Unique identifier for the submitted query.
136
+
137
+ Returns:
138
+ URL to the Snowflake query dashboard for the given query ID.
139
+ """
140
+ if "-" in account:
141
+ parts = account.split("-", 1)
142
+ if len(parts) == 2:
143
+ org_name, account_name = parts
144
+ base_url = f"https://app.snowflake.com/{org_name}/{account_name}"
145
+ else:
146
+ base_url = f"https://app.snowflake.com/{account}"
147
+ else:
148
+ # Simple account locator
149
+ base_url = f"https://app.snowflake.com/{account}"
150
+
151
+ return f"{base_url}/#/compute/history/queries/{query_id}/detail"
152
+
153
+
154
+ def _expand_batch_query(query: str, inputs: Dict[str, List[Any]]) -> Tuple[str, Dict[str, Any]]:
155
+ """
156
+ Expand a parameterized INSERT query with list inputs into a multi-row VALUES statement.
157
+
158
+ Example:
159
+ query = "INSERT INTO t (a, b) VALUES (%(a)s, %(b)s)"
160
+ inputs = {"a": [1, 2], "b": ["x", "y"]}
161
+
162
+ Produces:
163
+ "INSERT INTO t (a, b) VALUES (%(a_0)s, %(b_0)s), (%(a_1)s, %(b_1)s)"
164
+ flat_params = {"a_0": 1, "b_0": "x", "a_1": 2, "b_1": "y"}
165
+
166
+ Constraints:
167
+ - Query must contain exactly one VALUES (...) clause
168
+ - All input lists must have the same non-zero length
169
+ """
170
+ keys = list(inputs.keys())
171
+
172
+ lengths = {len(v) for v in inputs.values()}
173
+ if len(lengths) != 1:
174
+ raise ValueError("All batch input lists must have the same length.")
175
+
176
+ n_rows = lengths.pop()
177
+ if n_rows == 0:
178
+ raise ValueError("Batch inputs must contain at least one row.")
179
+
180
+ # Find all VALUES (...) clauses (case-insensitive)
181
+ # The inner group handles %(...)s placeholders which contain literal parentheses
182
+ matches = list(re.finditer(r"VALUES\s*\(((?:[^()]*|\([^)]*\))*)\)", query, re.IGNORECASE))
183
+ if not matches:
184
+ raise ValueError("Batch inputs require a query template with a VALUES (...) clause.")
185
+ if len(matches) > 1:
186
+ raise ValueError("Batch query expansion supports exactly one VALUES (...) clause.")
187
+
188
+ match = matches[0]
189
+ values_content = match.group(1)
190
+
191
+ value_rows = []
192
+ flat_params: Dict[str, Any] = {}
193
+
194
+ for i in range(n_rows):
195
+ row_content = values_content
196
+ for key in keys:
197
+ row_content = re.sub(
198
+ rf"%\({re.escape(key)}\)s",
199
+ f"%({key}_{i})s",
200
+ row_content,
201
+ )
202
+ flat_params[f"{key}_{i}"] = inputs[key][i]
203
+ value_rows.append(f"({row_content})")
204
+
205
+ before = query[: match.start()]
206
+ after = query[match.end() :]
207
+
208
+ expanded_query = before + "VALUES " + ", ".join(value_rows) + after
209
+ return expanded_query, flat_params
210
+
211
+
212
class SnowflakeConnector(AsyncConnector):
    """Async Flyte connector that submits, polls, and cancels Snowflake queries.

    Lifecycle:
      - ``create``: submits the SQL with ``execute_async`` and records the query id.
      - ``get``: polls query status; on success, points the output at a
        ``snowflake://`` URI understood by the dataframe decoder.
      - ``delete``: cancels the query server-side via SYSTEM$CANCEL_QUERY.
    """

    name = "Snowflake Connector"
    task_type_name = TASK_TYPE
    metadata_type = SnowflakeJobMetadata

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: Optional[Dict[str, Any]] = None,
        snowflake_private_key: Optional[str] = None,
        snowflake_private_key_passphrase: Optional[str] = None,
        **kwargs,
    ) -> SnowflakeJobMetadata:
        """
        Submit a query to Snowflake asynchronously.

        Args:
            task_template: The Flyte task template containing the SQL query and configuration.
            inputs: Optional dictionary of input parameters for parameterized queries.
            snowflake_private_key: The private key content set as a Flyte secret.
            snowflake_private_key_passphrase: The passphrase for the private key set as a Flyte secret, if any.

        Returns:
            A SnowflakeJobMetadata object containing the query ID and link to the query dashboard.

        Raises:
            ValueError: If required connection settings are missing from the task config.
        """
        custom = json_format.MessageToDict(task_template.custom) if task_template.custom else {}

        account = custom.get("account")
        if not account:
            raise ValueError("Missing Snowflake account. Set it through task configuration.")

        user = custom.get("user")
        database = custom.get("database")
        schema = custom.get("schema", "PUBLIC")
        warehouse = custom.get("warehouse")

        if not all([user, database, warehouse]):
            raise ValueError("User, database and warehouse must be specified in the task configuration.")

        # Get additional connection parameters from custom config
        connection_kwargs = custom.get("connection_kwargs", {})

        conn = await _get_snowflake_connection(
            account=account,
            user=user,
            database=database,
            schema=schema,
            warehouse=warehouse,
            private_key_content=snowflake_private_key,
            private_key_passphrase=snowflake_private_key_passphrase,
            **connection_kwargs,
        )

        query = task_template.sql.statement
        batch = custom.get("batch", False)

        def _submit_query() -> str:
            # execute_async returns immediately; Snowflake keeps running the
            # statement server-side and we track it by its query id (sfqid).
            cursor = conn.cursor()
            try:
                if batch and inputs:
                    expanded_query, flat_params = _expand_batch_query(query, inputs)
                    cursor.execute_async(expanded_query, flat_params)
                else:
                    cursor.execute_async(query, inputs)
                return cursor.sfqid
            finally:
                # Close the cursor even when submission raises; the previous
                # implementation leaked it on error.
                cursor.close()

        loop = asyncio.get_running_loop()
        query_id = await loop.run_in_executor(None, _submit_query)

        logger.info(f"Snowflake query submitted with ID: {query_id}")

        return SnowflakeJobMetadata(
            account=account,
            user=user,
            database=database,
            schema=schema,
            warehouse=warehouse,
            query_id=query_id,
            has_output=task_template.interface.outputs is not None
            and len(task_template.interface.outputs.variables) > 0,
            connection_kwargs=connection_kwargs,
        )

    async def get(
        self,
        resource_meta: SnowflakeJobMetadata,
        snowflake_private_key: Optional[str] = None,
        snowflake_private_key_passphrase: Optional[str] = None,
        **kwargs,
    ) -> Resource:
        """
        Poll the status of a Snowflake query.

        Args:
            resource_meta: The SnowflakeJobMetadata containing the query ID.
            snowflake_private_key: The private key content set as a Flyte secret.
            snowflake_private_key_passphrase: The passphrase for the private key set as a Flyte secret, if any.

        Returns:
            A Resource object containing the query results and a link to the query dashboard.
        """
        conn = await _get_snowflake_connection(
            account=resource_meta.account,
            user=resource_meta.user,
            database=resource_meta.database,
            schema=resource_meta.schema,
            warehouse=resource_meta.warehouse,
            private_key_content=snowflake_private_key,
            private_key_passphrase=snowflake_private_key_passphrase,
            **(resource_meta.connection_kwargs or {}),
        )

        log_link = TaskLog(
            uri=_construct_query_link(resource_meta.account, resource_meta.query_id),
            name="Snowflake Dashboard",
            ready=True,
            link_type=TaskLog.DASHBOARD,
        )

        def _get_query_status():
            # get_query_status_throw_if_error raises when the query itself
            # failed; surface that as a FAILED phase instead of crashing.
            try:
                return conn.get_query_status_throw_if_error(resource_meta.query_id), None
            except Exception as e:
                return None, str(e)

        loop = asyncio.get_running_loop()
        status, error = await loop.run_in_executor(None, _get_query_status)

        if error:
            logger.error(f"Snowflake query failed: {error}")
            return Resource(phase=TaskExecution.FAILED, message=error, log_links=[log_link])

        # Map Snowflake status to Flyte phase
        # Snowflake statuses: RUNNING, SUCCESS, FAILED_WITH_ERROR, ABORTING, etc.
        cur_phase = convert_to_flyte_phase(status.name)
        outputs = None

        if cur_phase == TaskExecution.SUCCEEDED and resource_meta.has_output:
            # The URI encodes everything the dataframe decoder needs to fetch
            # the result set later (see flyteplugins.snowflake.dataframe).
            output_location = (
                f"snowflake://{resource_meta.user}/{resource_meta.account}/{resource_meta.warehouse}/"
                f"{resource_meta.database}/{resource_meta.schema}/{resource_meta.query_id}"
            )
            outputs = {"results": DataFrame(uri=output_location)}

        return Resource(phase=cur_phase, message=status.name, log_links=[log_link], outputs=outputs)

    async def delete(
        self,
        resource_meta: SnowflakeJobMetadata,
        snowflake_private_key: Optional[str] = None,
        snowflake_private_key_passphrase: Optional[str] = None,
        **kwargs,
    ):
        """
        Cancel a running Snowflake query.

        Args:
            resource_meta: The SnowflakeJobMetadata containing the query ID.
            snowflake_private_key: The private key content set as a Flyte secret.
            snowflake_private_key_passphrase: The passphrase for the private key set as a Flyte secret, if any.
        """
        conn = await _get_snowflake_connection(
            account=resource_meta.account,
            user=resource_meta.user,
            database=resource_meta.database,
            schema=resource_meta.schema,
            warehouse=resource_meta.warehouse,
            private_key_content=snowflake_private_key,
            private_key_passphrase=snowflake_private_key_passphrase,
            **(resource_meta.connection_kwargs or {}),
        )

        def _cancel_query():
            cursor = conn.cursor()
            try:
                # Bind the query id instead of interpolating it into SQL text
                # (the connector's pyformat paramstyle supports %s binding).
                cursor.execute("SELECT SYSTEM$CANCEL_QUERY(%s)", (resource_meta.query_id,))
            finally:
                cursor.close()
                # Bug fix: do NOT close `conn` here. _get_snowflake_connection
                # caches connections via alru_cache, so closing it would hand
                # every subsequent call a dead cached connection.

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _cancel_query)

        logger.info(f"Snowflake query {resource_meta.query_id} cancelled")
402
+
403
+
404
+ ConnectorRegistry.register(SnowflakeConnector())
@@ -0,0 +1,124 @@
1
+ import os
2
+ import re
3
+ import typing
4
+ from typing import Optional
5
+
6
+ from flyte._utils import lazy_module
7
+ from flyte.io._dataframe.dataframe import DataFrame, DataFrameDecoder, DataFrameEncoder
8
+ from flyteidl2.core import literals_pb2, types_pb2
9
+
10
+ if typing.TYPE_CHECKING:
11
+ import pandas as pd
12
+
13
+ import snowflake.connector
14
+ else:
15
+ pd = lazy_module("pandas")
16
+
17
+ SNOWFLAKE = "snowflake"
18
+ PROTOCOL_SEP = "\\/|://|:"
19
+
20
+
21
+ def _get_private_key(private_key_content: str, private_key_passphrase: Optional[str] = None) -> bytes:
22
+ """Decode a PEM private key and return it in DER format."""
23
+ from cryptography.hazmat.backends import default_backend
24
+ from cryptography.hazmat.primitives import serialization
25
+
26
+ private_key_bytes = private_key_content.strip().encode()
27
+ password = private_key_passphrase.encode() if private_key_passphrase else None
28
+
29
+ private_key = serialization.load_pem_private_key(
30
+ private_key_bytes,
31
+ password=password,
32
+ backend=default_backend(),
33
+ )
34
+
35
+ return private_key.private_bytes(
36
+ encoding=serialization.Encoding.DER,
37
+ format=serialization.PrivateFormat.PKCS8,
38
+ encryption_algorithm=serialization.NoEncryption(),
39
+ )
40
+
41
+
42
+ def _get_connection(
43
+ user: str,
44
+ account: str,
45
+ database: str,
46
+ schema: str,
47
+ warehouse: str,
48
+ ) -> "snowflake.connector.SnowflakeConnection":
49
+ """Create a Snowflake connection using environment-provided credentials."""
50
+ import snowflake.connector
51
+
52
+ conn_params: dict[str, typing.Any] = {
53
+ "user": user,
54
+ "account": account,
55
+ "database": database,
56
+ "schema": schema,
57
+ "warehouse": warehouse,
58
+ }
59
+
60
+ # The secrets will be injected as environment variables.
61
+ private_key_content = os.environ.get("SNOWFLAKE_PRIVATE_KEY")
62
+ if private_key_content:
63
+ private_key_passphrase = os.environ.get("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE")
64
+ conn_params["private_key"] = _get_private_key(private_key_content, private_key_passphrase)
65
+
66
+ return snowflake.connector.connect(**conn_params)
67
+
68
+
69
+ def _write_to_sf(dataframe: DataFrame):
70
+ if not dataframe.uri:
71
+ raise ValueError("dataframe.uri cannot be None.")
72
+
73
+ from snowflake.connector.pandas_tools import write_pandas
74
+
75
+ uri = typing.cast(str, dataframe.uri)
76
+ _, user, account, warehouse, database, schema, table = re.split(PROTOCOL_SEP, uri)
77
+ df = typing.cast("pd.DataFrame", dataframe.val)
78
+
79
+ conn = _get_connection(user, account, database, schema, warehouse)
80
+ write_pandas(conn, df, table)
81
+
82
+
83
+ def _read_from_sf(
84
+ flyte_value: literals_pb2.StructuredDataset,
85
+ current_task_metadata: literals_pb2.StructuredDatasetMetadata,
86
+ ) -> "pd.DataFrame":
87
+ uri = flyte_value.uri
88
+ if not uri:
89
+ raise ValueError("flyte_value.uri cannot be empty.")
90
+
91
+ _, user, account, warehouse, database, schema, query_id = re.split(PROTOCOL_SEP, uri)
92
+
93
+ conn = _get_connection(user, account, database, schema, warehouse)
94
+ cs = conn.cursor()
95
+ cs.get_results_from_sfqid(query_id)
96
+ return cs.fetch_pandas_all()
97
+
98
+
99
+ class PandasToSnowflakeEncodingHandlers(DataFrameEncoder):
100
+ def __init__(self):
101
+ super().__init__(pd.DataFrame, SNOWFLAKE, "")
102
+
103
+ async def encode(
104
+ self,
105
+ dataframe: DataFrame,
106
+ structured_dataset_type: types_pb2.StructuredDatasetType,
107
+ ) -> literals_pb2.StructuredDataset:
108
+ _write_to_sf(dataframe)
109
+ return literals_pb2.StructuredDataset(
110
+ uri=typing.cast(str, dataframe.uri),
111
+ metadata=literals_pb2.StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type),
112
+ )
113
+
114
+
115
+ class SnowflakeToPandasDecodingHandler(DataFrameDecoder):
116
+ def __init__(self):
117
+ super().__init__(pd.DataFrame, SNOWFLAKE, "")
118
+
119
+ async def decode(
120
+ self,
121
+ flyte_value: literals_pb2.StructuredDataset,
122
+ current_task_metadata: literals_pb2.StructuredDatasetMetadata,
123
+ ) -> "pd.DataFrame":
124
+ return _read_from_sf(flyte_value, current_task_metadata)
@@ -0,0 +1,132 @@
1
+ import re
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional, Type
4
+
5
+ from flyte.connectors import AsyncConnectorExecutorMixin
6
+ from flyte.extend import TaskTemplate
7
+ from flyte.models import NativeInterface, SerializationContext
8
+ from flyteidl2.core import tasks_pb2
9
+
10
+
11
@dataclass
class SnowflakeConfig(object):
    """
    Configure a Snowflake Task using a `SnowflakeConfig` object.

    Additional connection parameters (role, authenticator, session_parameters, etc.) can be passed
    via connection_kwargs.
    See: https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api

    Args:
        account: The Snowflake account identifier.
        database: The Snowflake database name.
        schema: The Snowflake schema name.
        warehouse: The Snowflake warehouse name.
        user: The Snowflake user name.
        connection_kwargs: Optional dictionary of additional Snowflake connection parameters.
    """

    account: str  # e.g. "myorg-myaccount" or a bare account locator
    database: str
    schema: str
    warehouse: str
    user: str
    connection_kwargs: Optional[Dict[str, Any]] = None  # extra connect() params (role, password, ...)
35
+
36
+
37
class Snowflake(AsyncConnectorExecutorMixin, TaskTemplate):
    """Flyte task that runs a (possibly parameterized) SQL query on Snowflake."""

    _TASK_TYPE = "snowflake"

    def __init__(
        self,
        name: str,
        query_template: str,
        plugin_config: SnowflakeConfig,
        inputs: Optional[Dict[str, Type]] = None,
        output_dataframe_type: Optional[Type] = None,
        secret_group: Optional[str] = None,
        snowflake_private_key: Optional[str] = None,
        snowflake_private_key_passphrase: Optional[str] = None,
        batch: bool = False,
        **kwargs,
    ):
        """
        Task to run parameterized SQL queries against Snowflake.

        Args:
            name: The name of this task.
            query_template: The actual query to run. This can be parameterized using Python's
                printf-style string formatting with named parameters (e.g. %(param_name)s).
            plugin_config: `SnowflakeConfig` object containing connection metadata.
            inputs: Name and type of inputs specified as a dictionary.
            output_dataframe_type: If some data is produced by this query, then you can specify the
                output dataframe type.
            secret_group: Optional group for secrets in the secret store. The environment variable
                name is auto-generated from ``{secret_group}_{key}``, uppercased with hyphens
                replaced by underscores. If omitted, the key alone is used.
            snowflake_private_key: The secret key for the Snowflake private key (key-pair auth).
            snowflake_private_key_passphrase: The secret key for the private key passphrase
                (if encrypted).
            batch: When True, list inputs are expanded into a multi-row VALUES clause. The
                query_template should contain a single ``VALUES (%(col)s, ...)`` placeholder
                and each input should be a list of equal length.

        Note: For password authentication or other auth methods, pass them via `connection_kwargs`.
        """
        outputs = None
        if output_dataframe_type is not None:
            outputs = {"results": output_dataframe_type}

        super().__init__(
            name=name,
            interface=NativeInterface(
                {k: (v, None) for k, v in inputs.items()} if inputs else {},
                outputs or {},
            ),
            task_type=self._TASK_TYPE,
            image=None,
            **kwargs,
        )

        self.output_dataframe_type = output_dataframe_type
        self.plugin_config = plugin_config
        # Collapse all whitespace runs (\s covers newlines and tabs, so the
        # previous explicit replace("\n")/replace("\t") calls were redundant).
        self.query_template = re.sub(r"\s+", " ", query_template).strip()
        self.batch = batch
        self.secret_group = secret_group
        self.snowflake_private_key = snowflake_private_key
        self.snowflake_private_key_passphrase = snowflake_private_key_passphrase

    def _to_env_var(self, key: str) -> str:
        """Generate an environment variable name from the secret group and key.

        Example: group "grp" + key "my-key" -> "GRP_MY_KEY".
        """
        env_var = f"{self.secret_group}_{key}" if self.secret_group else key
        return env_var.replace("-", "_").upper()

    def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
        """Serialize the connection settings (and secret env-var mapping) for the connector."""
        config = {
            "account": self.plugin_config.account,
            "database": self.plugin_config.database,
            "schema": self.plugin_config.schema,
            "warehouse": self.plugin_config.warehouse,
            "user": self.plugin_config.user,
        }

        if self.batch:
            config["batch"] = True

        # Add additional connection parameters
        if self.plugin_config.connection_kwargs:
            config["connection_kwargs"] = self.plugin_config.connection_kwargs

        secrets = {}
        if self.snowflake_private_key is not None:
            secrets["snowflake_private_key"] = self._to_env_var(self.snowflake_private_key)
        if self.snowflake_private_key_passphrase is not None:
            secrets["snowflake_private_key_passphrase"] = self._to_env_var(self.snowflake_private_key_passphrase)
        if secrets:
            config["secrets"] = secrets

        return config

    def sql(self, sctx: SerializationContext) -> Optional[tasks_pb2.Sql]:
        """Serialize the query as an ANSI-dialect Sql proto.

        Note: the return annotation was previously ``Optional[str]``, which did
        not match the actual ``tasks_pb2.Sql`` return value.
        """
        return tasks_pb2.Sql(statement=self.query_template, dialect=tasks_pb2.Sql.Dialect.ANSI)
@@ -0,0 +1,178 @@
1
+ Metadata-Version: 2.4
2
+ Name: flyteplugins-snowflake
3
+ Version: 2.0.0
4
+ Summary: Snowflake plugin for flyte
5
+ Author-email: Kevin Su <pingsutw@users.noreply.github.com>
6
+ Requires-Python: >=3.10
7
+ Description-Content-Type: text/markdown
8
+ Requires-Dist: flyte[connector]
9
+ Requires-Dist: snowflake-connector-python[pandas]
10
+ Requires-Dist: cryptography
11
+
12
+ # Snowflake Plugin for Flyte
13
+
14
+ Run Snowflake SQL queries as Flyte tasks with parameterized inputs, key-pair authentication, batch inserts, and DataFrame support.
15
+
16
+ ## Installation
17
+
18
+ ```bash
19
+ pip install flyteplugins-snowflake
20
+ ```
21
+
22
+ ## Quick start
23
+
24
+ ```python
25
+ from flyteplugins.snowflake import Snowflake, SnowflakeConfig
26
+
27
+ import flyte
28
+
29
+ config = SnowflakeConfig(
30
+ account="myorg-myaccount",
31
+ user="flyte_user",
32
+ database="ANALYTICS",
33
+ schema="PUBLIC",
34
+ warehouse="COMPUTE_WH",
35
+ )
36
+
37
+ query = Snowflake(
38
+ name="count_users",
39
+ query_template="SELECT COUNT(*) FROM users",
40
+ plugin_config=config,
41
+ snowflake_private_key="snowflake-pk",
42
+ )
43
+ ```
44
+
45
+ ## Authentication
46
+
47
+ The plugin supports Snowflake [key-pair authentication](https://docs.snowflake.com/en/user-guide/key-pair-auth). Pass secret keys via `snowflake_private_key` (and optionally `snowflake_private_key_passphrase`).
48
+
49
+ ```python
50
+ task = Snowflake(
51
+ name="my_task",
52
+ query_template="SELECT 1",
53
+ plugin_config=config,
54
+ snowflake_private_key="private-key",
55
+ snowflake_private_key_passphrase="passphrase",
56
+ # Generates env vars: PRIVATE_KEY, PASSPHRASE
57
+ )
58
+ ```
59
+
60
+ For other auth methods (password, OAuth, etc.), pass them via `connection_kwargs`:
61
+
62
+ ```python
63
+ config = SnowflakeConfig(
64
+ account="myorg-myaccount",
65
+ user="flyte_user",
66
+ database="ANALYTICS",
67
+ schema="PUBLIC",
68
+ warehouse="COMPUTE_WH",
69
+ connection_kwargs={"password": "...", "role": "ADMIN"},
70
+ )
71
+ ```
72
+
73
+ ## Parameterized queries
74
+
75
+ Use `%(name)s` placeholders and typed `inputs`:
76
+
77
+ ```python
78
+ lookup = Snowflake(
79
+ name="lookup_user",
80
+ query_template="SELECT * FROM users WHERE id = %(user_id)s",
81
+ plugin_config=config,
82
+ inputs={"user_id": int},
83
+ output_dataframe_type=pd.DataFrame,
84
+ snowflake_private_key="snowflake-pk",
85
+ )
86
+ ```
87
+
88
+ ## Batch inserts
89
+
90
+ Set `batch=True` to expand list inputs into multi-row `VALUES` clauses:
91
+
92
+ ```python
93
+ insert_rows = Snowflake(
94
+ name="insert_users",
95
+ query_template="INSERT INTO users (id, name, age) VALUES (%(id)s, %(name)s, %(age)s)",
96
+ plugin_config=config,
97
+ inputs={"id": list[int], "name": list[str], "age": list[int]},
98
+ snowflake_private_key="snowflake-pk",
99
+ batch=True,
100
+ )
101
+
102
+ # Calling with id=[1,2], name=["Alice","Bob"], age=[30,25] expands to:
103
+ # INSERT INTO users (id, name, age) VALUES (%(id_0)s, %(name_0)s, %(age_0)s), (%(id_1)s, %(name_1)s, %(age_1)s)
104
+ ```
105
+
106
+ ## Reading results as DataFrames
107
+
108
+ Set `output_dataframe_type` to get query results as a pandas DataFrame:
109
+
110
+ ```python
111
+ import pandas as pd
112
+
113
+ select_task = Snowflake(
114
+ name="get_users",
115
+ query_template="SELECT * FROM users",
116
+ plugin_config=config,
117
+ output_dataframe_type=pd.DataFrame,
118
+ snowflake_private_key="snowflake-pk",
119
+ )
120
+ ```
121
+
122
+ ## Full example
123
+
124
+ ```python
125
+ import pandas as pd
126
+ from flyteplugins.snowflake import Snowflake, SnowflakeConfig
127
+
128
+ import flyte
129
+
130
+ config = SnowflakeConfig(
131
+ user="KEVIN",
132
+ account="PWGJLTH-XKB21544",
133
+ database="FLYTE",
134
+ schema="PUBLIC",
135
+ warehouse="COMPUTE_WH",
136
+ )
137
+
138
+ insert_task = Snowflake(
139
+ name="insert_rows",
140
+ inputs={"id": list[int], "name": list[str], "age": list[int]},
141
+ plugin_config=config,
142
+ query_template="INSERT INTO FLYTE.PUBLIC.TEST (ID, NAME, AGE) VALUES (%(id)s, %(name)s, %(age)s)",
143
+ snowflake_private_key="snowflake",
144
+ batch=True,
145
+ )
146
+
147
+ select_task = Snowflake(
148
+ name="select_all",
149
+ output_dataframe_type=pd.DataFrame,
150
+ plugin_config=config,
151
+ query_template="SELECT * FROM FLYTE.PUBLIC.TEST",
152
+ snowflake_private_key="snowflake",
153
+ )
154
+
155
+ snowflake_env = flyte.TaskEnvironment.from_task("snowflake_env", insert_task, select_task)
156
+
157
+ env = flyte.TaskEnvironment(
158
+ name="example_env",
159
+ image=flyte.Image.from_debian_base().with_pip_packages("flyteplugins-snowflake"),
160
+ secrets=[flyte.Secret(key="snowflake", as_env_var="SNOWFLAKE_PRIVATE_KEY")],
161
+ depends_on=[snowflake_env],
162
+ )
163
+
164
+
165
+ @env.task
166
+ async def main(ids: list[int], names: list[str], ages: list[int]) -> float:
167
+ await insert_task(id=ids, name=names, age=ages)
168
+ df = await select_task()
169
+ return df["AGE"].mean().item()
170
+
171
+
172
+ if __name__ == "__main__":
173
+ flyte.init_from_config()
174
+ run = flyte.with_runcontext(mode="remote").run(
175
+ main, ids=[123, 456], names=["Kevin", "Alice"], ages=[30, 25],
176
+ )
177
+ print(run.url)
178
+ ```
@@ -0,0 +1,9 @@
1
+ flyteplugins/snowflake/__init__.py,sha256=pduVR0IVn0gejuQoP3KjkUfO4OGypHocm4lFh_vnZXQ,1673
2
+ flyteplugins/snowflake/connector.py,sha256=ahFjYUVuYYvgW41pQ86spHQFP7HIAMa_PdkJvu4UkeI,14292
3
+ flyteplugins/snowflake/dataframe.py,sha256=yUkRVCGWZ0IqbK9CIaKsp-MkwshJdbJPd49jWlzM7BM,3995
4
+ flyteplugins/snowflake/task.py,sha256=jDdYMjD5F7JZGOu2Mp5ax1kqV3R7wz6BIVbw7-KAaGE,5416
5
+ flyteplugins_snowflake-2.0.0.dist-info/METADATA,sha256=g77dxtEKw9fcSlWlkolw9Imacs9G41eWjBKoOsVFvZ0,4603
6
+ flyteplugins_snowflake-2.0.0.dist-info/WHEEL,sha256=YCfwYGOYMi5Jhw2fU4yNgwErybb2IX5PEwBKV4ZbdBo,91
7
+ flyteplugins_snowflake-2.0.0.dist-info/entry_points.txt,sha256=JwU7E683aUhA8NdHlUkF3RioiWf7aHE-c3mMgkUP4qM,83
8
+ flyteplugins_snowflake-2.0.0.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
9
+ flyteplugins_snowflake-2.0.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [flyte.connectors]
2
+ snowflake = flyteplugins.snowflake.connector:SnowflakeConnector
@@ -0,0 +1 @@
1
+ flyteplugins