soda-snowflake 4.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,6 @@
1
+ Metadata-Version: 2.4
2
+ Name: soda-snowflake
3
+ Version: 4.0.5
4
+ Requires-Dist: soda-core==4.0.5
5
+ Requires-Dist: snowflake-connector-python>=3.0
6
+ Dynamic: requires-dist
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,24 @@
1
#!/usr/bin/env python
"""Packaging script for the Soda Snowflake data source plugin."""

from setuptools import setup

package_name = "soda-snowflake"
package_version = "4.0.5"
description = "Soda Snowflake V4"

# soda-core is pinned to exactly this package's version.
# snowflake-connector-python 4.0 is also fine, but 3.0 stays the lower bound
# for backwards compatibility.
requires = [
    f"soda-core=={package_version}",
    "snowflake-connector-python>=3.0",
]

setup(
    name=package_name,
    version=package_version,
    # Hand the summary to setup() so it ends up in the built package metadata.
    description=description,
    install_requires=requires,
    package_dir={"": "src"},
    entry_points={
        "soda.plugins.data_source.snowflake": [
            "SnowflakeDataSourceImpl = soda_snowflake.common.data_sources.snowflake_data_source:SnowflakeDataSourceImpl",
        ],
    },
)
@@ -0,0 +1,200 @@
1
+ from logging import Logger
2
+ from numbers import Number
3
+ from typing import TYPE_CHECKING, Optional
4
+
5
+ from soda_core.common.data_source_connection import DataSourceConnection
6
+ from soda_core.common.data_source_impl import DataSourceImpl
7
+ from soda_core.common.logging_constants import soda_logger
8
+ from soda_core.common.metadata_types import SodaDataTypeName
9
+ from soda_core.common.soda_cloud_dto import SamplerType
10
+ from soda_core.common.sql_ast import COLUMN, COUNT, DISTINCT, TUPLE, VALUES
11
+ from soda_core.common.sql_dialect import SqlDialect
12
+ from soda_core.contracts.impl.contract_verification_impl import ContractImpl
13
+ from soda_snowflake.common.data_sources.snowflake_data_source_connection import (
14
+ SnowflakeDataSource as SnowflakeDataSourceModel,
15
+ )
16
+ from soda_snowflake.common.data_sources.snowflake_data_source_connection import (
17
+ SnowflakeDataSourceConnection,
18
+ )
19
+
20
# Module-level logger; reuses the shared soda_logger object imported from soda_core.
logger: Logger = soda_logger

# NOTE(review): ContractImpl is already imported unconditionally at the top of
# this module, so this type-checking-only import looks redundant — confirm
# before removing either one.
if TYPE_CHECKING:
    from soda_core.contracts.impl.contract_verification_impl import ContractImpl

# Long-form spellings of the timestamp type names. The dialect in this module
# lists them next to the timestamp_ntz / timestamp_tz / timestamp_ltz spellings.
TIMESTAMP_WITHOUT_TIME_ZONE = "timestamp without time zone"
TIMESTAMP_WITH_TIME_ZONE = "timestamp with time zone"
TIMESTAMP_WITH_LOCAL_TIME_ZONE = "timestamp with local time zone"
28
+
29
+
30
class SnowflakeDataSourceImpl(DataSourceImpl, model_class=SnowflakeDataSourceModel):
    """Snowflake data source implementation, declared with SnowflakeDataSourceModel as its model class."""

    def __init__(self, data_source_model: SnowflakeDataSourceModel, connection: Optional[DataSourceConnection] = None):
        super().__init__(data_source_model=data_source_model, connection=connection)

    def _create_sql_dialect(self) -> SqlDialect:
        """Return a SnowflakeSqlDialect bound to this data source."""
        return SnowflakeSqlDialect(data_source_impl=self)

    def _create_data_source_connection(self) -> DataSourceConnection:
        """Build the Snowflake connection from the model's name and connection properties."""
        model = self.data_source_model
        return SnowflakeDataSourceConnection(
            name=model.name,
            connection_properties=model.connection_properties,
        )

    def switch_warehouse(self, warehouse: str, contract_impl: ContractImpl) -> None:
        """Log the warehouse used for a contract verification and switch to it when needed.

        A switch is executed only when ``warehouse`` is non-empty and differs
        from ``contract_impl.datasource_warehouse``; otherwise the warehouse
        recorded on the contract is just logged.
        """
        current_warehouse = contract_impl.datasource_warehouse
        dataset_name = contract_impl.dataset_identifier.to_string()

        needs_switch = bool(warehouse) and current_warehouse != warehouse
        if not needs_switch:
            logger.info(f"Using warehouse '{current_warehouse}' for Contract verification of dataset '{dataset_name}'")
            return

        if current_warehouse is None:
            message = f"Setting warehouse to '{warehouse}' for Contract verification of dataset '{dataset_name}'"
        else:
            message = f"Switching warehouse from '{current_warehouse}' to '{warehouse}' for Contract verification of dataset '{dataset_name}'"
        logger.info(message)
        self._execute_switch_warehouse(warehouse)

    def _execute_switch_warehouse(self, warehouse: str) -> None:
        """Run ``USE WAREHOUSE`` for the given name; the name is interpolated into the SQL as-is."""
        self.execute_query(f"USE WAREHOUSE {warehouse}")

    def get_current_warehouse(self) -> Optional[str]:
        """Return the value of ``CURRENT_WAREHOUSE()``, or None when the query yields no usable value."""
        rows = self.execute_query("SELECT CURRENT_WAREHOUSE()").rows
        if not rows:
            return None

        first_row = rows[0]
        if not first_row:
            return None

        # An empty or NULL first column is reported as None as well.
        return first_row[0] or None
69
+
70
+
71
class SnowflakeSqlDialect(SqlDialect):
    """SQL dialect overrides for Snowflake.

    Covers identifier casing, VALUES and tuple rendering, data type name
    mappings and sampling.
    """

    # Groups of Soda data type names that are treated as synonyms of each other.
    SODA_DATA_TYPE_SYNONYMS = (
        (SodaDataTypeName.TEXT, SodaDataTypeName.VARCHAR, SodaDataTypeName.CHAR),
        (
            SodaDataTypeName.NUMERIC,
            SodaDataTypeName.DECIMAL,
            SodaDataTypeName.INTEGER,
            SodaDataTypeName.BIGINT,
            SodaDataTypeName.SMALLINT,
        ),
        (SodaDataTypeName.FLOAT, SodaDataTypeName.DOUBLE),
    )

    def default_casify(self, identifier: str) -> str:
        """Return the identifier in upper case."""
        return identifier.upper()

    def metadata_casify(self, identifier: str) -> str:
        """Metadata identifiers are not uppercased for Snowflake."""
        return identifier

    def create_table_casify_qualified_name(self, qualified_name: str) -> str:
        """Upper-case only the last dot-separated element of ``qualified_name``.

        A bare table name (no dot) is upper-cased as a whole; any prefix before
        the last dot is returned unchanged.
        """
        prefix, dot, last_part = qualified_name.rpartition(".")
        return f"{prefix}{dot}{last_part.upper()}"

    def build_cte_values_sql(self, values: VALUES, alias_columns: Optional[list[COLUMN]]) -> str:
        """Render a VALUES expression as ``SELECT * FROM VALUES`` followed by one row per line.

        ``alias_columns`` is accepted for signature compatibility but is not
        used by this implementation.
        """
        # Optional[...] rather than ``X | None``: this module evaluates
        # annotations at definition time, and ``list[...] | None`` needs Python 3.10+.
        rows_sql = ",\n".join(self.build_expression_sql(value) for value in values.values)
        return " SELECT * FROM VALUES\n" + rows_sql

    def _build_tuple_sql(self, tuple: TUPLE) -> str:
        """Render a tuple, using the ARRAY_CONSTRUCT form inside a COUNT + DISTINCT context."""
        if tuple.check_context(COUNT) and tuple.check_context(DISTINCT):
            return self._build_tuple_sql_in_distinct(tuple)
        return super()._build_tuple_sql(tuple)

    def _build_tuple_sql_in_distinct(self, tuple: TUPLE) -> str:
        """Prefix the parent's tuple rendering with ``ARRAY_CONSTRUCT``."""
        # Presumably the parent renders a parenthesized list, which makes this a
        # function call — TODO confirm against SqlDialect._build_tuple_sql.
        return f"ARRAY_CONSTRUCT{super()._build_tuple_sql(tuple)}"

    def _get_data_type_name_synonyms(self) -> list[list[str]]:
        """Return groups of data source type names; each inner list is one group of synonyms."""
        return [
            ["varchar", "text", "string"],
            ["number", "decimal", "numeric", "int", "integer", "bigint", "smallint", "tinyint", "byteint"],
            ["float", "float4", "float8", "double", "double precision", "real"],
            ["timestamp", "datetime", "timestamp_ntz", TIMESTAMP_WITHOUT_TIME_ZONE],
            ["timestamp_ltz", TIMESTAMP_WITH_LOCAL_TIME_ZONE],
            ["timestamp_tz", TIMESTAMP_WITH_TIME_ZONE],
        ]

    def get_data_source_data_type_name_by_soda_data_type_names(self) -> dict:
        """
        Maps Soda data type names to data source type names.
        """
        return {
            SodaDataTypeName.CHAR: "char",  # Note: by default a char is 1 byte in Snowflake!
            SodaDataTypeName.VARCHAR: "varchar",
            SodaDataTypeName.TEXT: "text",  # alias for varchar
            SodaDataTypeName.SMALLINT: "smallint",
            SodaDataTypeName.INTEGER: "integer",
            SodaDataTypeName.BIGINT: "bigint",
            SodaDataTypeName.DECIMAL: "number",  # decimal & numeric → number
            SodaDataTypeName.NUMERIC: "number",
            SodaDataTypeName.FLOAT: "float",  # float / double → float
            SodaDataTypeName.DOUBLE: "float",
            SodaDataTypeName.TIMESTAMP: "timestamp_ntz",  # default timestamp in snowflake
            SodaDataTypeName.TIMESTAMP_TZ: "timestamp_tz",
            SodaDataTypeName.DATE: "date",
            SodaDataTypeName.TIME: "time",
            SodaDataTypeName.BOOLEAN: "boolean",
        }

    def data_type_has_parameter_character_maximum_length(self, data_type_name) -> bool:
        """Return True for the character type names listed here (case-insensitive match)."""
        return data_type_name.lower() in ["varchar", "char", "character", "text"]

    # TODO: test this thoroughly. The mapping below was generated using AI just to be able to test the E2E.
    def get_soda_data_type_name_by_data_source_data_type_names(self) -> dict[str, SodaDataTypeName]:
        """Maps lower-case data source type names to Soda data type names."""
        return {
            "varchar": SodaDataTypeName.VARCHAR,
            "char": SodaDataTypeName.CHAR,
            "character": SodaDataTypeName.CHAR,
            "text": SodaDataTypeName.TEXT,
            "string": SodaDataTypeName.VARCHAR,
            "smallint": SodaDataTypeName.SMALLINT,
            "integer": SodaDataTypeName.INTEGER,
            "int": SodaDataTypeName.INTEGER,
            "bigint": SodaDataTypeName.BIGINT,
            "tinyint": SodaDataTypeName.SMALLINT,
            "byteint": SodaDataTypeName.SMALLINT,
            "number": SodaDataTypeName.NUMERIC,
            "decimal": SodaDataTypeName.DECIMAL,
            "numeric": SodaDataTypeName.NUMERIC,
            "float": SodaDataTypeName.FLOAT,
            "float4": SodaDataTypeName.FLOAT,
            "float8": SodaDataTypeName.DOUBLE,
            "double": SodaDataTypeName.DOUBLE,
            "double precision": SodaDataTypeName.DOUBLE,
            "real": SodaDataTypeName.FLOAT,
            "timestamp": SodaDataTypeName.TIMESTAMP,
            "timestamp_ntz": SodaDataTypeName.TIMESTAMP,
            TIMESTAMP_WITHOUT_TIME_ZONE: SodaDataTypeName.TIMESTAMP,
            "datetime": SodaDataTypeName.TIMESTAMP,
            "timestamp_tz": SodaDataTypeName.TIMESTAMP_TZ,
            TIMESTAMP_WITH_TIME_ZONE: SodaDataTypeName.TIMESTAMP_TZ,
            "timestamp_ltz": SodaDataTypeName.TIMESTAMP_TZ,
            TIMESTAMP_WITH_LOCAL_TIME_ZONE: SodaDataTypeName.TIMESTAMP_TZ,
            "date": SodaDataTypeName.DATE,
            "time": SodaDataTypeName.TIME,
            "boolean": SodaDataTypeName.BOOLEAN,
        }

    def data_type_has_parameter_datetime_precision(self, data_type_name) -> bool:
        """Return True for the timestamp type names listed here (case-insensitive match)."""
        return data_type_name.lower() in [
            "timestamp",
            "timestamp_ntz",
            TIMESTAMP_WITHOUT_TIME_ZONE,
            "timestamp_tz",
            TIMESTAMP_WITH_TIME_ZONE,
            "timestamp_ltz",
            TIMESTAMP_WITH_LOCAL_TIME_ZONE,
        ]

    def _build_sample_sql(self, sampler_type: str, sample_size: Number) -> str:
        """Render the TABLESAMPLE clause for an absolute row limit.

        Raises:
            ValueError: for any sampler type other than SamplerType.ABSOLUTE_LIMIT.
        """
        if sampler_type == SamplerType.ABSOLUTE_LIMIT:
            # The sample size is truncated to an integer row count.
            return f"TABLESAMPLE ({int(sample_size)} ROWS)"
        raise ValueError(f"Unsupported sample type: {sampler_type}")
@@ -0,0 +1,140 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from abc import ABC
5
+ from pathlib import Path
6
+ from typing import Dict, Literal, Optional
7
+
8
+ from cryptography.hazmat.backends import default_backend
9
+ from cryptography.hazmat.primitives import serialization
10
+ from pydantic import Field, SecretStr, field_validator
11
+ from snowflake import connector
12
+ from soda_core.common.data_source_connection import DataSourceConnection
13
+ from soda_core.common.logging_constants import soda_logger
14
+ from soda_core.model.data_source.data_source import DataSourceBase
15
+ from soda_core.model.data_source.data_source_connection_properties import (
16
+ DataSourceConnectionProperties,
17
+ )
18
+
19
# Module-level logger; reuses the shared soda_logger object imported from soda_core.
logger: logging.Logger = soda_logger

# Description text reused by the `user` field of several auth models in this module.
USER_DESCRIPTION = "Username for authentication"
22
+
23
+
24
# Common base type for the Snowflake connection-properties models in this
# module; it declares no fields of its own.
class SnowflakeConnectionProperties(DataSourceConnectionProperties, ABC):
    ...
26
+
27
+
28
# Connection settings shared by every Snowflake authentication variant below.
# Only `account` is required; all other fields default to None.
class SnowflakeSharedConnectionProperties(SnowflakeConnectionProperties, ABC):
    account: str = Field(..., description="Snowflake account identifier")
    warehouse: Optional[str] = Field(None, description="Name of the warehouse to use")
    database: Optional[str] = Field(None, description="Name of the database to use")
    role: Optional[str] = Field(None, description="Role to assume after connecting")
    session_parameters: Optional[Dict[str, str]] = Field(None, description="Session-level parameters")
    host: Optional[str] = Field(None, description="Host name of the Snowflake account")
35
+
36
+
37
# Username/password authentication; both fields are required and the password
# is held as a SecretStr.
class SnowflakePasswordAuth(SnowflakeSharedConnectionProperties):
    user: str = Field(..., description=USER_DESCRIPTION)
    password: SecretStr = Field(..., description="User password")
40
+
41
+
42
# Key-pair authentication where the PEM-encoded private key is supplied inline.
class SnowflakeKeyPairAuth(SnowflakeSharedConnectionProperties):
    user: str = Field(..., description=USER_DESCRIPTION)
    private_key: SecretStr = Field(..., description="Private key for authentication")
    private_key_passphrase: Optional[SecretStr] = Field(None, description="Passphrase if private key is encrypted")

    def to_connection_kwargs(self) -> dict:
        """Return the inherited connection kwargs with ``private_key`` set to raw key bytes.

        The configured PEM key is loaded (with the passphrase when one is set)
        and re-serialized as unencrypted PKCS8 DER bytes.
        """
        connection_kwargs = super().to_connection_kwargs()
        connection_kwargs["private_key"] = self._decrypt(self.private_key, self.private_key_passphrase)
        return connection_kwargs

    def _decrypt(self, private_key: SecretStr, private_key_passphrase: Optional[SecretStr]) -> bytes:
        """Load a PEM private key and return it as unencrypted PKCS8 DER bytes.

        Args:
            private_key: The PEM-encoded private key, wrapped in a SecretStr.
            private_key_passphrase: Passphrase for an encrypted key. A missing
                or falsy passphrase is passed to the loader as ``None``.

        Returns:
            The private key serialized as DER / PKCS8 without encryption.
        """
        pem_bytes = private_key.get_secret_value().encode()
        passphrase_bytes = private_key_passphrase.get_secret_value().encode() if private_key_passphrase else None

        loaded_key = serialization.load_pem_private_key(pem_bytes, password=passphrase_bytes, backend=default_backend())

        return loaded_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
67
+
68
+
69
# Key-pair authentication where the private key is read from a file path.
class SnowflakeKeyPairFileAuth(SnowflakeSharedConnectionProperties):
    user: str = Field(..., description=USER_DESCRIPTION)
    private_key_path: Path = Field(..., description="Path to private key file")
    private_key_passphrase: Optional[SecretStr] = Field(None, description="Passphrase if private key is encrypted")

    def to_connection_kwargs(self) -> dict:
        """Return the inherited connection kwargs plus the key file settings.

        ``private_key_file`` is always set to the configured path;
        ``private_key_file_pwd`` is added only for a non-empty passphrase.
        """
        kwargs = super().to_connection_kwargs()
        kwargs["private_key_file"] = self.private_key_path

        passphrase = self.private_key_passphrase
        secret = passphrase.get_secret_value() if passphrase is not None else None
        if secret:
            kwargs["private_key_file_pwd"] = secret
        return kwargs
82
+
83
+
84
# OAuth authentication with a caller-supplied access token; `authenticator`
# must be the literal "oauth" and has no default.
class SnowflakeOAuthAuth(SnowflakeSharedConnectionProperties):
    authenticator: Literal["oauth"] = Field(..., description="Use OAuth access token")
    token: SecretStr = Field(..., description="OAuth access token")
87
+
88
+
89
# OAuth Client Credentials authentication. Client id, client secret and token
# request URL are required; the scope is optional. All are held as SecretStr.
class SnowflakeClientCredentialsOAuthAuth(SnowflakeSharedConnectionProperties):
    # User is not required for OAuth Client Credentials Flow.
    authenticator: Literal["OAUTH_CLIENT_CREDENTIALS"] = Field(
        ..., description="Authenticator to use for OAuth Client Credentials Flow."
    )
    oauth_client_id: SecretStr = Field(..., description="Client ID for OAuth Client Credentials Flow.")
    oauth_client_secret: SecretStr = Field(..., description="Client secret for OAuth Client Credentials Flow.")
    oauth_token_request_url: SecretStr = Field(..., description="Token request URL for OAuth Client Credentials Flow.")
    oauth_scope: Optional[SecretStr] = Field(None, description="Scope for OAuth Client Credentials Flow if required.")
98
+
99
+
100
# External-browser SSO authentication; `authenticator` defaults to
# "externalbrowser" and no other fields are added.
class SnowflakeSSOAuth(SnowflakeSharedConnectionProperties):
    # User is not required for SSO Flow.
    authenticator: Literal["externalbrowser"] = Field("externalbrowser", description="Use external browser SSO login")
103
+
104
+
105
# Data source model for type "snowflake"; the connection settings are read
# from the `connection` alias.
class SnowflakeDataSource(DataSourceBase, ABC):
    type: Literal["snowflake"] = Field("snowflake")
    connection_properties: SnowflakeConnectionProperties = Field(
        ..., alias="connection", description="Snowflake connection configuration"
    )

    @field_validator("connection_properties", mode="before")
    @classmethod
    def infer_connection_type(cls, value):
        """Pick the concrete connection-properties model for the raw ``connection`` input.

        The first matching rule wins, in this order: ``password``,
        ``private_key``, ``private_key_path``, ``token``, then the
        ``authenticator`` values ``externalbrowser`` and
        ``OAUTH_CLIENT_CREDENTIALS``.

        Args:
            value: Either a mapping of connection properties or an
                already-built SnowflakeConnectionProperties instance.

        Returns:
            An instance of the matching connection-properties model.

        Raises:
            ValueError: if ``value`` is not a mapping or no rule matches.
        """
        from collections.abc import Mapping

        # An already-constructed model needs no inference; hand it back unchanged.
        if isinstance(value, SnowflakeConnectionProperties):
            return value

        # Reject None / scalars / lists with a validation-friendly error instead
        # of failing on the membership tests below.
        if not isinstance(value, Mapping):
            raise ValueError(
                f"Snowflake connection must be a mapping of connection properties, got {type(value).__name__}"
            )

        if "password" in value:
            return SnowflakePasswordAuth(**value)
        elif "private_key" in value:
            return SnowflakeKeyPairAuth(**value)
        elif "private_key_path" in value:
            return SnowflakeKeyPairFileAuth(**value)
        elif "token" in value:
            return SnowflakeOAuthAuth(**value)
        elif value.get("authenticator") == "externalbrowser":
            return SnowflakeSSOAuth(**value)
        elif value.get("authenticator") == "OAUTH_CLIENT_CREDENTIALS":
            return SnowflakeClientCredentialsOAuthAuth(**value)
        raise ValueError("Could not infer Snowflake connection type from input")
127
+
128
+
129
class SnowflakeDataSourceConnection(DataSourceConnection):
    """Data source connection that opens its connection through the Snowflake connector."""

    def __init__(self, name: str, connection_properties: DataSourceConnectionProperties):
        super().__init__(name, connection_properties)

    def _create_connection(
        self,
        config: SnowflakeConnectionProperties,
    ):
        """Open a connector connection from ``config``, identifying the application as "Soda"."""
        connect_kwargs = config.to_connection_kwargs()
        return connector.connect(application="Soda", **connect_kwargs)
@@ -0,0 +1,28 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+
5
+ from helpers.data_source_test_helper import DataSourceTestHelper
6
+
7
+
8
class SnowflakeDataSourceTestHelper(DataSourceTestHelper):
    """Test helper that supplies Snowflake-specific settings to DataSourceTestHelper."""

    def _create_database_name(self) -> str | None:
        # Falls back to "soda_test" when SNOWFLAKE_DATABASE is not set.
        return os.getenv("SNOWFLAKE_DATABASE", "soda_test")

    def _create_data_source_yaml_str(self) -> str:
        """
        Called in _create_data_source_impl to initialize self.data_source_impl.
        self.database_name and self.schema_name are available if appropriate for the data source type.
        """
        # Account, user and password come from SNOWFLAKE_* environment variables;
        # an unset variable is rendered as the text "None".
        # NOTE(review): in this rendering the YAML lines carry no leading
        # whitespace, so the keys after `connection:` are not nested under it —
        # confirm the indentation against the packaged file.
        return f"""
type: snowflake
name: {self.name}
connection:
account: {os.getenv("SNOWFLAKE_ACCOUNT")}
user: {os.getenv("SNOWFLAKE_USER")}
password: {os.getenv("SNOWFLAKE_PASSWORD")}
database: {self.dataset_prefix[0]}
"""

    def _adjust_schema_name(self, schema_name: str) -> str:
        # Schema names are upper-cased.
        return schema_name.upper()
@@ -0,0 +1,6 @@
1
+ Metadata-Version: 2.4
2
+ Name: soda-snowflake
3
+ Version: 4.0.5
4
+ Requires-Dist: soda-core==4.0.5
5
+ Requires-Dist: snowflake-connector-python>=3.0
6
+ Dynamic: requires-dist
@@ -0,0 +1,10 @@
1
+ setup.py
2
+ src/soda_snowflake.egg-info/PKG-INFO
3
+ src/soda_snowflake.egg-info/SOURCES.txt
4
+ src/soda_snowflake.egg-info/dependency_links.txt
5
+ src/soda_snowflake.egg-info/entry_points.txt
6
+ src/soda_snowflake.egg-info/requires.txt
7
+ src/soda_snowflake.egg-info/top_level.txt
8
+ src/soda_snowflake/common/data_sources/snowflake_data_source.py
9
+ src/soda_snowflake/common/data_sources/snowflake_data_source_connection.py
10
+ src/soda_snowflake/test_helpers/snowflake_data_source_test_helper.py
@@ -0,0 +1,2 @@
1
+ [soda.plugins.data_source.snowflake]
2
+ SnowflakeDataSourceImpl = soda_snowflake.common.data_sources.snowflake_data_source:SnowflakeDataSourceImpl
@@ -0,0 +1,2 @@
1
+ soda-core==4.0.5
2
+ snowflake-connector-python>=3.0
@@ -0,0 +1 @@
1
+ soda_snowflake