flyteplugins-snowflake 2.0.0b54__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,4 @@
1
+ from flyteplugins.snowflake.connector import SnowflakeConnector
2
+ from flyteplugins.snowflake.task import Snowflake, SnowflakeConfig
3
+
4
+ __all__ = ["Snowflake", "SnowflakeConfig", "SnowflakeConnector"]
@@ -0,0 +1,271 @@
1
+ import asyncio
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional
4
+
5
+ from async_lru import alru_cache
6
+ from flyte import logger
7
+ from flyte.connectors import AsyncConnector, ConnectorRegistry, Resource, ResourceMeta
8
+ from flyte.connectors.utils import convert_to_flyte_phase
9
+ from flyte.io import DataFrame
10
+ from flyteidl2.core.execution_pb2 import TaskExecution, TaskLog
11
+ from flyteidl2.core.tasks_pb2 import TaskTemplate
12
+ from google.protobuf import json_format
13
+ from snowflake import connector
14
+
15
+ TASK_TYPE = "snowflake"
16
+
17
+
18
@dataclass
class SnowflakeJobMetadata(ResourceMeta):
    """
    Handle for a submitted Snowflake query.

    ``SnowflakeConnector.create`` returns an instance of this class, and
    ``get`` / ``delete`` read it back to reconnect with the same settings and
    to address the query by its id.
    """

    # Connection coordinates copied from the task's custom configuration.
    account: str
    user: str
    database: str
    schema: str
    warehouse: str
    # Snowflake query id (``cursor.sfqid``) captured right after submission.
    query_id: str
    # True when the task interface declares at least one output variable.
    has_output: bool
    # Extra keyword arguments forwarded to ``snowflake.connector.connect``.
    connection_kwargs: Optional[Dict[str, Any]] = None
28
+
29
+
30
def _get_private_key(private_key_content: str, private_key_passphrase: Optional[str] = None) -> bytes:
    """
    Convert a PEM-encoded private key into unencrypted PKCS#8 DER bytes.

    :param private_key_content: PEM text of the private key; surrounding whitespace is stripped.
    :param private_key_passphrase: Passphrase for an encrypted key. ``None`` or an empty
        string means the key is loaded without a password.
    :return: The key serialized as DER / PKCS8 with no encryption.
    """
    # Imported lazily so the cryptography package is only loaded when key-pair auth is used.
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization

    pem_bytes = private_key_content.strip().encode()

    passphrase = None
    if private_key_passphrase:
        passphrase = private_key_passphrase.encode()

    loaded_key = serialization.load_pem_private_key(
        pem_bytes,
        password=passphrase,
        backend=default_backend(),
    )

    der_bytes = loaded_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return der_bytes
51
+
52
+
53
def _freeze_for_cache_key(value: Any) -> Any:
    """Return a hashable stand-in for ``value``, recursing into dicts, lists, tuples and sets."""
    if isinstance(value, dict):
        frozen_items = [(repr(key), _freeze_for_cache_key(item)) for key, item in value.items()]
        # Sort on the key repr only, so values of mixed types are never compared with each other.
        return ("dict", tuple(sorted(frozen_items, key=lambda pair: pair[0])))
    if isinstance(value, (list, tuple)):
        return (type(value).__name__, tuple(_freeze_for_cache_key(item) for item in value))
    if isinstance(value, (set, frozenset)):
        return ("set", frozenset(_freeze_for_cache_key(item) for item in value))
    return value


class _ConnectionKwargs:
    """Hashable carrier for extra connection kwargs so they can be part of a cache key."""

    def __init__(self, kwargs: Dict[str, Any]):
        self.kwargs = dict(kwargs)
        self._cache_key = _freeze_for_cache_key(self.kwargs)

    def __hash__(self) -> int:
        return hash(self._cache_key)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, _ConnectionKwargs) and self._cache_key == other._cache_key


@alru_cache
async def _get_cached_snowflake_connection(
    account: str,
    user: str,
    database: str,
    schema: str,
    warehouse: str,
    private_key_content: Optional[str],
    private_key_passphrase: Optional[str],
    extra: _ConnectionKwargs,
) -> connector.SnowflakeConnection:
    """Open a Snowflake connection; results are cached per distinct argument combination."""

    def _create_connection():
        connection_params = {
            "account": account,
            "user": user,
            "database": database,
            "schema": schema,
            "warehouse": warehouse,
            **extra.kwargs,
        }

        # Add private key authentication if provided
        if private_key_content:
            connection_params["private_key"] = _get_private_key(private_key_content, private_key_passphrase)

        # Let Snowflake connector validate authentication requirements
        return connector.connect(**connection_params)

    # connector.connect blocks, so run it on the loop's default executor.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _create_connection)


async def _get_snowflake_connection(
    account: str,
    user: str,
    database: str,
    schema: str,
    warehouse: str,
    private_key_content: Optional[str] = None,
    private_key_passphrase: Optional[str] = None,
    **connection_kwargs,
) -> connector.SnowflakeConnection:
    """
    Create and return a Snowflake connection, reusing a cached one when the arguments match.

    Supports private key authentication (recommended) and other auth methods via connection_kwargs.
    See: https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api

    The cache key previously hashed ``connection_kwargs`` directly, which raised
    ``TypeError`` for dict or list values such as ``session_parameters``. The kwargs are
    now wrapped in a hashable key object before reaching the cached coroutine, and are
    still passed to ``connector.connect`` unchanged.

    :param account: Snowflake account identifier.
    :param user: Snowflake user name.
    :param database: Database to connect to.
    :param schema: Schema to connect to.
    :param warehouse: Warehouse to use.
    :param private_key_content: PEM private key text for key-pair authentication, if any.
    :param private_key_passphrase: Passphrase for the private key, if it is encrypted.
    :param connection_kwargs: Extra keyword arguments merged into the ``connector.connect`` call.
    :return: A (possibly shared) ``SnowflakeConnection``. Callers must not close it.
    """
    return await _get_cached_snowflake_connection(
        account,
        user,
        database,
        schema,
        warehouse,
        private_key_content,
        private_key_passphrase,
        _ConnectionKwargs(connection_kwargs),
    )
91
+
92
+
93
+ def _construct_query_link(account: str, query_id: str) -> str:
94
+ """Construct a Snowflake console link for the query."""
95
+ if "-" in account:
96
+ parts = account.split("-", 1)
97
+ if len(parts) == 2:
98
+ org_name, account_name = parts
99
+ base_url = f"https://app.snowflake.com/{org_name}/{account_name}"
100
+ else:
101
+ base_url = f"https://app.snowflake.com/{account}"
102
+ else:
103
+ # Simple account locator
104
+ base_url = f"https://app.snowflake.com/{account}"
105
+
106
+ return f"{base_url}/#/compute/history/queries/{query_id}/detail"
107
+
108
+
109
class SnowflakeConnector(AsyncConnector):
    """
    Flyte async connector that submits, polls and cancels Snowflake queries.

    All blocking Snowflake driver calls are pushed to the event loop's default
    executor. Connections come from ``_get_snowflake_connection`` and may be shared
    between calls, so no method here closes the connection it obtains.
    """

    name = "Snowflake Connector"
    task_type_name = TASK_TYPE
    metadata_type = SnowflakeJobMetadata

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: Optional[Dict[str, Any]] = None,
        snowflake_private_key: Optional[str] = None,
        snowflake_private_key_passphrase: Optional[str] = None,
        **kwargs,
    ) -> SnowflakeJobMetadata:
        """
        Submit a query to Snowflake asynchronously.

        :param task_template: Task template whose ``custom`` struct carries the connection
            settings and whose ``sql.statement`` is the query to run.
        :param inputs: Values passed to ``cursor.execute_async`` as query parameters.
        :param snowflake_private_key: PEM private key text for key-pair authentication.
        :param snowflake_private_key_passphrase: Passphrase for the private key, if encrypted.
        :return: Metadata identifying the submitted query.
        :raises ValueError: If account, user, database or warehouse is missing.
        """
        custom = json_format.MessageToDict(task_template.custom) if task_template.custom else {}

        account = custom.get("account")
        if not account:
            raise ValueError("Missing Snowflake account. Please set it through task configuration.")

        user = custom.get("user")
        database = custom.get("database")
        schema = custom.get("schema", "PUBLIC")
        warehouse = custom.get("warehouse")

        if not all([user, database, warehouse]):
            raise ValueError("User, database and warehouse must be specified in the task configuration.")

        # Get additional connection parameters from custom config
        connection_kwargs = custom.get("connection_kwargs", {})

        conn = await _get_snowflake_connection(
            account=account,
            user=user,
            database=database,
            schema=schema,
            warehouse=warehouse,
            private_key_content=snowflake_private_key,
            private_key_passphrase=snowflake_private_key_passphrase,
            **connection_kwargs,
        )

        query = task_template.sql.statement

        def _execute_query():
            cursor = conn.cursor()
            try:
                cursor.execute_async(query, inputs)
                return cursor.sfqid
            finally:
                # Always release the cursor, including when submission raises.
                cursor.close()

        loop = asyncio.get_running_loop()
        query_id = await loop.run_in_executor(None, _execute_query)

        logger.info(f"Snowflake query submitted with ID: {query_id}")

        return SnowflakeJobMetadata(
            account=account,
            user=user,
            database=database,
            schema=schema,
            warehouse=warehouse,
            query_id=query_id,
            has_output=task_template.interface.outputs is not None
            and len(task_template.interface.outputs.variables) > 0,
            connection_kwargs=connection_kwargs,
        )

    async def get(
        self,
        resource_meta: SnowflakeJobMetadata,
        snowflake_private_key: Optional[str] = None,
        snowflake_private_key_passphrase: Optional[str] = None,
        **kwargs,
    ) -> Resource:
        """
        Poll the status of a Snowflake query.

        :param resource_meta: Metadata returned by ``create``.
        :param snowflake_private_key: PEM private key text for key-pair authentication.
        :param snowflake_private_key_passphrase: Passphrase for the private key, if encrypted.
        :return: A ``Resource`` with the mapped phase, a dashboard log link and, when the
            query succeeded and the task declares outputs, a ``results`` DataFrame reference.
        """
        conn = await _get_snowflake_connection(
            account=resource_meta.account,
            user=resource_meta.user,
            database=resource_meta.database,
            schema=resource_meta.schema,
            warehouse=resource_meta.warehouse,
            private_key_content=snowflake_private_key,
            private_key_passphrase=snowflake_private_key_passphrase,
            **(resource_meta.connection_kwargs or {}),
        )

        log_link = TaskLog(
            uri=_construct_query_link(resource_meta.account, resource_meta.query_id),
            name="Snowflake Dashboard",
            ready=True,
            link_type=TaskLog.DASHBOARD,
        )

        def _get_query_status():
            try:
                status = conn.get_query_status_throw_if_error(resource_meta.query_id)
                return status, None
            except Exception as e:
                # str(e) can be empty; fall back to repr so a failure is never
                # mistaken for "no error" by the caller.
                return None, str(e) or repr(e)

        loop = asyncio.get_running_loop()
        status, error = await loop.run_in_executor(None, _get_query_status)

        if error is not None:
            logger.error(f"Snowflake query failed: {error}")
            return Resource(phase=TaskExecution.FAILED, message=error, log_links=[log_link])

        # Map Snowflake status to Flyte phase
        # Snowflake statuses: RUNNING, SUCCESS, FAILED_WITH_ERROR, ABORTING, etc.
        cur_phase = convert_to_flyte_phase(status.name)
        outputs = None

        if cur_phase == TaskExecution.SUCCEEDED and resource_meta.has_output:
            # Construct the output URI for the results
            output_location = (
                f"snowflake://{resource_meta.account}/{resource_meta.database}/"
                f"{resource_meta.schema}/{resource_meta.query_id}"
            )
            outputs = {"results": DataFrame(uri=output_location)}

        return Resource(phase=cur_phase, message=status.name, log_links=[log_link], outputs=outputs)

    async def delete(
        self,
        resource_meta: SnowflakeJobMetadata,
        snowflake_private_key: Optional[str] = None,
        snowflake_private_key_passphrase: Optional[str] = None,
        **kwargs,
    ):
        """
        Cancel a running Snowflake query.

        :param resource_meta: Metadata returned by ``create``.
        :param snowflake_private_key: PEM private key text for key-pair authentication.
        :param snowflake_private_key_passphrase: Passphrase for the private key, if encrypted.
        """
        conn = await _get_snowflake_connection(
            account=resource_meta.account,
            user=resource_meta.user,
            database=resource_meta.database,
            schema=resource_meta.schema,
            warehouse=resource_meta.warehouse,
            private_key_content=snowflake_private_key,
            private_key_passphrase=snowflake_private_key_passphrase,
            **(resource_meta.connection_kwargs or {}),
        )

        def _cancel_query():
            cursor = conn.cursor()
            try:
                cursor.execute(f"SELECT SYSTEM$CANCEL_QUERY('{resource_meta.query_id}')")
            finally:
                # Only the cursor is closed. The connection comes from the shared cache,
                # and closing it would hand a dead connection to later create/get/delete
                # calls that use the same connection arguments.
                cursor.close()

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _cancel_query)
        logger.info(f"Snowflake query {resource_meta.query_id} cancelled")
269
+
270
+
271
+ ConnectorRegistry.register(SnowflakeConnector())
@@ -0,0 +1,103 @@
1
+ import re
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional, Type
4
+
5
+ from flyte.connectors import AsyncConnectorExecutorMixin
6
+ from flyte.extend import TaskTemplate
7
+ from flyte.io import DataFrame
8
+ from flyte.models import NativeInterface, SerializationContext
9
+ from flyteidl2.core import tasks_pb2
10
+
11
+
12
@dataclass
class SnowflakeConfig:
    """
    Connection settings for a Snowflake task.

    Anything beyond the five required fields (role, authenticator, session_parameters,
    etc.) goes into ``connection_kwargs``.
    See: https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api
    """

    # Required connection coordinates.
    account: str
    database: str
    schema: str
    warehouse: str
    user: str
    # Optional extra connection parameters; None when not supplied.
    connection_kwargs: Optional[Dict[str, Any]] = None
28
+
29
+
30
class Snowflake(AsyncConnectorExecutorMixin, TaskTemplate):
    """Flyte task that runs a SQL statement on Snowflake through the Snowflake connector."""

    _TASK_TYPE = "snowflake"

    def __init__(
        self,
        name: str,
        query_template: str,
        plugin_config: SnowflakeConfig,
        inputs: Optional[Dict[str, Type]] = None,
        output_dataframe_type: Optional[Type[DataFrame]] = None,
        snowflake_private_key: Optional[str] = None,
        snowflake_private_key_passphrase: Optional[str] = None,
        **kwargs,
    ):
        """
        To be used to query Snowflake databases.

        :param name: The name of this task, should be unique in the project
        :param query_template: The actual query to run. Runs of whitespace (including newlines
            and tabs) are collapsed to single spaces. The connector in this package passes the
            task inputs to ``cursor.execute_async`` as bind parameters, so reference inputs with
            driver placeholders such as ``%(name)s``.
        :param plugin_config: SnowflakeConfig object (includes connection_kwargs for additional parameters)
        :param inputs: Name and type of inputs specified as an ordered dictionary
        :param output_dataframe_type: If some data is produced by this query, then you can specify the
            output dataframe type. When set, the task exposes a single output named ``results``.
        :param snowflake_private_key: The name of the secret containing the Snowflake private key for key-pair auth.
        :param snowflake_private_key_passphrase: The name of the secret containing the private key passphrase
            (if encrypted).

        Note: For password authentication or other auth methods, pass them via plugin_config.connection_kwargs.
        """
        outputs = {"results": output_dataframe_type} if output_dataframe_type is not None else {}
        super().__init__(
            name=name,
            interface=NativeInterface(
                {k: (v, None) for k, v in inputs.items()} if inputs else {},
                outputs,
            ),
            task_type=self._TASK_TYPE,
            image=None,
            **kwargs,
        )
        self.output_dataframe_type = output_dataframe_type
        self.plugin_config = plugin_config
        # r"\s+" already matches newlines and tabs, so no separate replace pass is needed.
        self.query_template = re.sub(r"\s+", " ", query_template).strip()
        self.snowflake_private_key = snowflake_private_key
        self.snowflake_private_key_passphrase = snowflake_private_key_passphrase

    def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
        """
        Build the task's custom configuration dictionary.

        :param sctx: Serialization context (not used here).
        :return: Connection settings, plus ``connection_kwargs`` when non-empty and a
            ``secrets`` mapping when either private-key secret name was provided.
        """
        config = {
            "account": self.plugin_config.account,
            "database": self.plugin_config.database,
            "schema": self.plugin_config.schema,
            "warehouse": self.plugin_config.warehouse,
            "user": self.plugin_config.user,
        }

        # Add additional connection parameters
        if self.plugin_config.connection_kwargs:
            config["connection_kwargs"] = self.plugin_config.connection_kwargs

        secrets = {}
        if self.snowflake_private_key is not None:
            secrets["snowflake_private_key"] = self.snowflake_private_key
        if self.snowflake_private_key_passphrase is not None:
            secrets["snowflake_private_key_passphrase"] = self.snowflake_private_key_passphrase
        if secrets:
            config["secrets"] = secrets

        return config

    def sql(self, sctx: SerializationContext) -> Optional[str]:
        """
        Return the query wrapped in a ``tasks_pb2.Sql`` message with the ANSI dialect.

        :param sctx: Serialization context (not used here).
        """
        # NOTE(review): annotated Optional[str] but a tasks_pb2.Sql message is returned;
        # confirm against the base class signature before tightening the annotation.
        return tasks_pb2.Sql(statement=self.query_template, dialect=tasks_pb2.Sql.Dialect.ANSI)
@@ -0,0 +1,42 @@
1
+ Metadata-Version: 2.4
2
+ Name: flyteplugins-snowflake
3
+ Version: 2.0.0b54
4
+ Summary: Snowflake plugin for flyte
5
+ Author-email: Kevin Su <pingsutw@users.noreply.github.com>
6
+ Requires-Python: >=3.10
7
+ Description-Content-Type: text/markdown
8
+ Requires-Dist: flyte[connector]
9
+ Requires-Dist: snowflake-connector-python
10
+ Requires-Dist: cryptography
11
+
12
+ # Snowflake Plugin for Flyte
13
+
14
+ This plugin provides Snowflake integration for Flyte, enabling you to run Snowflake queries as Flyte tasks.
15
+
16
+ ## Installation
17
+
18
+ ```bash
19
+ pip install flyteplugins-snowflake
20
+ ```
21
+
22
+ ## Usage
23
+
24
+ ```python
25
+ from flyteplugins.snowflake import Snowflake, SnowflakeConfig
26
+
27
+ config = SnowflakeConfig(
28
+ account="myaccount.us-east-1",
29
+ user="myuser",
30
+ database="mydb",
31
+ schema="PUBLIC",
32
+ warehouse="mywarehouse",
33
+ )
34
+
35
+ task = Snowflake(
36
+ name="my_query",
37
+ query_template="INSERT INTO FLYTE.PUBLIC.TEST (ID, NAME, AGE) VALUES (%(id)s, %(name)s, %(age)s);",
38
+ plugin_config=config,
39
+ inputs={"id": int, "name": str, "age": int},
40
+ snowflake_private_key="snowflake-private-key-secret",
41
+ )
42
+ ```
@@ -0,0 +1,9 @@
1
+ flyteplugins/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ flyteplugins/snowflake/__init__.py,sha256=o44Xb0MqEYmjLDx-lWR5JHyAX16ReCHTW1andZiGJ5U,197
3
+ flyteplugins/snowflake/connector.py,sha256=-PTgNflQcTLn4qVP4b61aoyspTjIah76lRFpGGUHamg,9308
4
+ flyteplugins/snowflake/task.py,sha256=KbOA9wdQ60h1vrlhb30SBy9b8QJjwtQ1cDCJ_QQ0Aho,4129
5
+ flyteplugins_snowflake-2.0.0b54.dist-info/METADATA,sha256=_juIGsQVQKe7me65vRc2oZluL8OEBKX4lSoJTaQmZhM,1054
6
+ flyteplugins_snowflake-2.0.0b54.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
7
+ flyteplugins_snowflake-2.0.0b54.dist-info/entry_points.txt,sha256=JwU7E683aUhA8NdHlUkF3RioiWf7aHE-c3mMgkUP4qM,83
8
+ flyteplugins_snowflake-2.0.0b54.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
9
+ flyteplugins_snowflake-2.0.0b54.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.10.2)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [flyte.connectors]
2
+ snowflake = flyteplugins.snowflake.connector:SnowflakeConnector
@@ -0,0 +1 @@
1
+ flyteplugins