soda-sqlserver 4.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,6 @@
+ Metadata-Version: 2.4
+ Name: soda-sqlserver
+ Version: 4.0.5
+ Requires-Dist: soda-core==4.0.5
+ Requires-Dist: pyodbc
+ Dynamic: requires-dist
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
@@ -0,0 +1,24 @@
+ #!/usr/bin/env python
+
+ from setuptools import setup
+
+ package_name = "soda-sqlserver"
+ package_version = "4.0.5"
+ description = "Soda SQL Server V4"
+
+ requires = [
+     f"soda-core=={package_version}",
+     "pyodbc",
+ ]
+
+ setup(
+     name=package_name,
+     version=package_version,
+     install_requires=requires,
+     package_dir={"": "src"},
+     entry_points={
+         "soda.plugins.data_source.sqlserver": [
+             "SqlServerDataSourceImpl = soda_sqlserver.common.data_sources.sqlserver_data_source:SqlServerDataSourceImpl",
+         ],
+     },
+ )
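
For context, the soda.plugins.data_source.sqlserver entry point registered in the setup.py above is how soda-core discovers this adapter at runtime. A minimal sketch of such a lookup using only the standard library (the group name is taken from setup.py; the discovery loop itself is illustrative, not soda-core's actual plugin loader):

    from importlib.metadata import entry_points

    # List every implementation registered under the group declared in setup.py.
    for ep in entry_points(group="soda.plugins.data_source.sqlserver"):
        impl_class = ep.load()  # imports SqlServerDataSourceImpl from its module
        print(ep.name, "->", impl_class)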
@@ -0,0 +1,427 @@
+ import logging
+ from copy import deepcopy
+ from datetime import date, datetime
+ from typing import Optional
+
+ from soda_core.common.data_source_connection import DataSourceConnection
+ from soda_core.common.data_source_impl import DataSourceImpl
+ from soda_core.common.dataset_identifier import DatasetIdentifier
+ from soda_core.common.logging_constants import soda_logger
+ from soda_core.common.metadata_types import SodaDataTypeName, SqlDataType
+ from soda_core.common.sql_ast import (
+     AND,
+     COLUMN,
+     COUNT,
+     CREATE_TABLE,
+     CREATE_TABLE_AS_SELECT,
+     CREATE_TABLE_IF_NOT_EXISTS,
+     CREATE_VIEW,
+     DISTINCT,
+     DROP_TABLE,
+     DROP_TABLE_IF_EXISTS,
+     DROP_VIEW,
+     DROP_VIEW_IF_EXISTS,
+     FROM,
+     INSERT_INTO,
+     INSERT_INTO_VIA_SELECT,
+     INTO,
+     LENGTH,
+     LIMIT,
+     OFFSET,
+     ORDER_BY_ASC,
+     REGEX_LIKE,
+     SELECT,
+     STAR,
+     STRING_HASH,
+     TUPLE,
+     VALUES,
+     WHERE,
+     WITH,
+     SqlExpressionStr,
+ )
+ from soda_core.common.sql_dialect import SqlDialect
+ from soda_sqlserver.common.data_sources.sqlserver_data_source_connection import (
+     SqlServerDataSource as SqlServerDataSourceModel,
+ )
+ from soda_sqlserver.common.data_sources.sqlserver_data_source_connection import (
+     SqlServerDataSourceConnection,
+ )
+
+ logger: logging.Logger = soda_logger
+
+
+ class SqlServerDataSourceImpl(DataSourceImpl, model_class=SqlServerDataSourceModel):
+     def __init__(self, data_source_model: SqlServerDataSourceModel, connection: Optional[DataSourceConnection] = None):
+         super().__init__(data_source_model=data_source_model, connection=connection)
+
+     def _create_sql_dialect(self) -> SqlDialect:
+         return SqlServerSqlDialect(data_source_impl=self)
+
+     def _create_data_source_connection(self) -> DataSourceConnection:
+         return SqlServerDataSourceConnection(
+             name=self.data_source_model.name, connection_properties=self.data_source_model.connection_properties
+         )
+
+
+ class SqlServerSqlDialect(SqlDialect):
+     DEFAULT_QUOTE_CHAR = "["  # Do not use this! Always use quote_default()
+     SODA_DATA_TYPE_SYNONYMS = ((SodaDataTypeName.TEXT, SodaDataTypeName.VARCHAR),)
+
+     def build_select_sql(self, select_elements: list, add_semicolon: bool = True) -> str:
+         statement_lines: list[str] = []
+         statement_lines.extend(self._build_cte_sql_lines(select_elements))
+         statement_lines.extend(self._build_select_sql_lines(select_elements))
+         statement_lines.extend(self._build_into_sql_lines(select_elements))
+         statement_lines.extend(self._build_from_sql_lines(select_elements))
+         statement_lines.extend(self._build_where_sql_lines(select_elements))
+         statement_lines.extend(self._build_group_by_sql_lines(select_elements))
+         statement_lines.extend(self._build_order_by_lines(select_elements))
+
+         offset_line = self._build_offset_line(select_elements)
+         if offset_line:
+             statement_lines.append(offset_line)
+
+         limit_line = self._build_limit_line(select_elements)
+         if limit_line:
+             statement_lines.append(limit_line)
+
+         return "\n".join(statement_lines) + (";" if add_semicolon else "")
+
+     def _build_select_sql_lines(self, select_elements: list) -> list[str]:
+         # Use the default implementation, but we need to handle the case where the select elements contain a LIMIT statement.
+         select_sql_lines: list[str] = super()._build_select_sql_lines(select_elements)
+         if self.__requires_select_top(select_elements):
+             limit_element: LIMIT = [
+                 select_element for select_element in select_elements if isinstance(select_element, LIMIT)
+             ][0]
+             select_sql_lines[0] = select_sql_lines[0].replace("SELECT ", f"SELECT TOP {limit_element.limit} ")
+         return select_sql_lines
+
+     def __requires_select_top(self, select_elements: list) -> bool:
+         # We require TOP when there is a LIMIT statement and no OFFSET statement.
+         return any(isinstance(select_element, LIMIT) for select_element in select_elements) and not any(
+             isinstance(select_element, OFFSET) for select_element in select_elements
+         )
+
+     def _build_limit_line(self, select_elements: list) -> Optional[str]:
+         # First, check if there is a LIMIT statement in the select elements.
+         limit_statement_present = any(isinstance(select_element, LIMIT) for select_element in select_elements)
+         if not limit_statement_present:
+             return None
+
+         # Check if there is an OFFSET statement in the select elements. If so, use the default logic.
+         uses_offset = any(isinstance(select_element, OFFSET) for select_element in select_elements)
+         if uses_offset:
+             return super()._build_limit_line(select_elements)
+         else:
+             return None  # This case (limit, but no offset) is handled by the _build_select_sql_lines method; it adds TOP N instead of FETCH NEXT.
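+
+     # Illustrative note (editorial): with this dialect, LIMIT-only queries are rendered
+     # with TOP, while LIMIT + OFFSET uses the OFFSET ... FETCH clause, e.g.
+     #   SELECT TOP 10 [id] FROM [db].[schema].[t]
+     #   SELECT [id] FROM [db].[schema].[t] ORDER BY [id] ASC OFFSET 20 ROWS FETCH NEXT 10 ROWS ONLY
+     # (SQL Server only allows OFFSET/FETCH after an ORDER BY clause.)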
+
+     def literal_date(self, date: date):
+         """Technically dates can be passed directly as strings, but this is more explicit."""
+         date_string = date.strftime("%Y-%m-%d")
+         return f"CAST('{date_string}' AS DATE)"
+
+     def literal_datetime(self, datetime: datetime):
+         return f"'{datetime.isoformat(timespec='milliseconds')}'"
+
+     def literal_boolean(self, boolean: bool):
+         return "1" if boolean is True else "0"
+
+     def quote_default(self, identifier: Optional[str]) -> Optional[str]:
+         return f"[{identifier}]" if isinstance(identifier, str) and len(identifier) > 0 else None
+
+     def create_schema_if_not_exists_sql(self, prefixes: list[str], add_semicolon: bool = True) -> str:
+         schema_name: str = prefixes[1]
+         return f"""
+             IF NOT EXISTS ( SELECT *
+                             FROM sys.schemas
+                             WHERE name = N'{schema_name}' )
+             EXEC('CREATE SCHEMA [{schema_name}]')
+         """ + (
+             ";" if add_semicolon else ""
+         )
+
+     def build_drop_table_sql(self, drop_table: DROP_TABLE | DROP_TABLE_IF_EXISTS, add_semicolon: bool = True) -> str:
+         if_exists_sql: str = (
+             f"IF OBJECT_ID('{drop_table.fully_qualified_table_name}', 'U') IS NOT NULL"
+             if isinstance(drop_table, DROP_TABLE_IF_EXISTS)
+             else ""
+         )
+         return f"{if_exists_sql} DROP TABLE {drop_table.fully_qualified_table_name}" + (";" if add_semicolon else "")
+
+     def _build_create_table_statement_sql(self, create_table: CREATE_TABLE | CREATE_TABLE_IF_NOT_EXISTS) -> str:
+         if_not_exists_sql: str = (
+             f"IF OBJECT_ID('{create_table.fully_qualified_table_name}', 'U') IS NULL"
+             if isinstance(create_table, CREATE_TABLE_IF_NOT_EXISTS)
+             else ""
+         )
+         create_table_sql: str = f"{if_not_exists_sql} CREATE TABLE {create_table.fully_qualified_table_name} "
+         return create_table_sql
+
+     def _build_length_sql(self, length: LENGTH) -> str:
+         return f"LEN({self.build_expression_sql(length.expression)})"
+
+     def sql_expr_timestamp_literal(self, datetime_in_iso8601: str) -> str:
+         return f"'{datetime_in_iso8601}'"
+
+     def sql_expr_timestamp_truncate_day(self, timestamp_literal: str) -> str:
+         return f"DATETRUNC(DAY, {timestamp_literal})"
+
+     def sql_expr_timestamp_add_day(self, timestamp_literal: str) -> str:
+         return f"DATEADD(DAY, 1, {timestamp_literal})"
+
+     def _build_tuple_sql(self, tuple: TUPLE) -> str:
+         if tuple.check_context(COUNT) and tuple.check_context(DISTINCT):
+             return f"CHECKSUM{super()._build_tuple_sql(tuple)}"
+         if tuple.check_context(VALUES):
+             # In build_cte_values_sql, elements end up in a top-level SELECT statement, so parentheses can't be used here.
+             return ", ".join(self.build_expression_sql(e) for e in tuple.expressions)
+         return super()._build_tuple_sql(tuple)
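+
+     # Illustrative note (editorial, column names assumed): T-SQL does not accept
+     # COUNT(DISTINCT (a, b)) on a tuple, so the tuple is first collapsed into a single
+     # hashed value, e.g. COUNT(DISTINCT CHECKSUM([country], [city])).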
+
+     def _build_regex_like_sql(self, matches: REGEX_LIKE) -> str:
+         expression: str = self.build_expression_sql(matches.expression)
+         regex_pattern = matches.regex_pattern
+         # Alpha range expansion doesn't work properly for case-sensitive ranges in SQL Server.
+         # This is quite a hack to fit the common use cases; regexes are generally only partially supported anyway.
+         regex_pattern = regex_pattern.replace("a-z", "abcdefghijklmnopqrstuvwxyz")
+         regex_pattern = regex_pattern.replace("A-Z", "ABCDEFGHIJKLMNOPQRSTUVWXYZ")
+         # Collations define rules for sorting strings and distinguishing similar characters.
+         # See: https://learn.microsoft.com/en-us/sql/relational-databases/collations/collation-and-unicode-support?view=sql-server-ver17
+         # CS: case-sensitive; AS: accent-sensitive.
+         # The default is SQL_Latin1_General_Cp1_CI_AS (case-insensitive), so we replace it with a case-sensitive collation.
+         return f"PATINDEX ('%{regex_pattern}%', {expression} COLLATE SQL_Latin1_General_Cp1_CS_AS) > 0"
+
+     def supports_regex_advanced(self) -> bool:
+         return False
+
+     def build_cte_values_sql(self, values: VALUES, alias_columns: list[COLUMN] | None) -> str:
+         return "\nUNION ALL\n".join(["SELECT " + self.build_expression_sql(value) for value in values.values])
+
+     def select_all_paginated_sql(
+         self,
+         dataset_identifier: DatasetIdentifier,
+         columns: list[str],
+         filter: Optional[str],
+         order_by: list[str],
+         limit: int,
+         offset: int,
+     ) -> str:
+         where_clauses = []
+
+         if filter:
+             where_clauses.append(SqlExpressionStr(filter))
+
+         statements = [
+             SELECT(columns or [STAR()]),
+             FROM(table_name=dataset_identifier.dataset_name, table_prefix=dataset_identifier.prefixes),
+             WHERE.optional(AND.optional(where_clauses)),
+             *[ORDER_BY_ASC(c) for c in order_by],
+             OFFSET(offset),
+             LIMIT(limit),
+         ]
+
+         return self.build_select_sql(statements)
+
+     def _build_limit_sql(self, limit_element: LIMIT) -> str:
+         return f"FETCH NEXT {limit_element.limit} ROWS ONLY"
+
+     def _build_offset_sql(self, offset_element: OFFSET) -> str:
+         return f"OFFSET {offset_element.offset} ROWS"
+
+     def _get_data_type_name_synonyms(self) -> list[list[str]]:
+         return [
+             ["varchar", "nvarchar"],
+             ["char", "nchar"],
+             ["int", "integer"],
+             ["bigint"],
+             ["smallint"],
+             ["real"],
+             ["float", "double precision"],
+             ["datetime2", "datetime"],
+         ]
+
+     # copied from redshift
+     def get_data_source_data_type_name_by_soda_data_type_names(self) -> dict:
+         return {
+             SodaDataTypeName.CHAR: "char",
+             SodaDataTypeName.VARCHAR: "varchar",
+             SodaDataTypeName.TEXT: "varchar",
+             SodaDataTypeName.SMALLINT: "smallint",
+             SodaDataTypeName.INTEGER: "int",
+             SodaDataTypeName.BIGINT: "bigint",
+             SodaDataTypeName.NUMERIC: "numeric",
+             SodaDataTypeName.DECIMAL: "decimal",
+             SodaDataTypeName.FLOAT: "real",
+             SodaDataTypeName.DOUBLE: "float",
+             SodaDataTypeName.TIMESTAMP: "datetime2",
+             SodaDataTypeName.TIMESTAMP_TZ: "datetimeoffset",
+             SodaDataTypeName.DATE: "date",
+             SodaDataTypeName.TIME: "time",
+             SodaDataTypeName.BOOLEAN: "bit",
+         }
+
+     # copied from redshift
+     def get_soda_data_type_name_by_data_source_data_type_names(self) -> dict[str, SodaDataTypeName]:
+         return {
+             # Character types
+             "char": SodaDataTypeName.CHAR,
+             "varchar": SodaDataTypeName.VARCHAR,
+             "text": SodaDataTypeName.TEXT,
+             "nchar": SodaDataTypeName.CHAR,
+             "nvarchar": SodaDataTypeName.VARCHAR,
+             "ntext": SodaDataTypeName.TEXT,
+             # Integer types
+             "tinyint": SodaDataTypeName.SMALLINT,
+             "smallint": SodaDataTypeName.SMALLINT,
+             "int": SodaDataTypeName.INTEGER,
+             "bigint": SodaDataTypeName.BIGINT,
+             # Exact numeric types
+             "numeric": SodaDataTypeName.NUMERIC,
+             "decimal": SodaDataTypeName.DECIMAL,
+             # Approximate numeric types
+             "real": SodaDataTypeName.FLOAT,
+             "float": SodaDataTypeName.DOUBLE,
+             # Date/time types
+             "date": SodaDataTypeName.DATE,
+             "time": SodaDataTypeName.TIME,
+             "datetime2": SodaDataTypeName.TIMESTAMP,
+             "datetimeoffset": SodaDataTypeName.TIMESTAMP_TZ,
+             "datetime": SodaDataTypeName.TIMESTAMP,
+             "smalldatetime": SodaDataTypeName.TIMESTAMP,
+             # Boolean type
+             "bit": SodaDataTypeName.BOOLEAN,
+         }
+
+     def supports_data_type_character_maximum_length(self) -> bool:
+         return True
+
+     def supports_data_type_numeric_precision(self) -> bool:
+         return True
+
+     def supports_data_type_numeric_scale(self) -> bool:
+         return True
+
+     def supports_data_type_datetime_precision(self) -> bool:
+         return True
+
+     def supports_datetime_microseconds(self) -> bool:
+         return False
+
+     def data_type_has_parameter_character_maximum_length(self, data_type_name) -> bool:
+         return data_type_name.lower() in ["varchar", "char", "nvarchar", "nchar"]
+
+     def data_type_has_parameter_numeric_precision(self, data_type_name) -> bool:
+         return data_type_name.lower() in ["numeric", "decimal", "float"]
+
+     def data_type_has_parameter_numeric_scale(self, data_type_name) -> bool:
+         return data_type_name.lower() in ["numeric", "decimal"]
+
+     def data_type_has_parameter_datetime_precision(self, data_type_name) -> bool:
+         return data_type_name.lower() in [
+             "time",
+             "datetime2",
+             "datetimeoffset",
+         ]
+
+     def default_varchar_length(self) -> Optional[int]:
+         return 255
+
+     def is_quoted(self, identifier: str) -> bool:
+         return identifier.startswith("[") and identifier.endswith("]")
+
+     def build_insert_into_sql(self, insert_into: INSERT_INTO, add_semicolon: bool = True) -> str:
+         # SQL Server supports at most 1000 rows per INSERT statement. If there are more values,
+         # split the insert into multiple statements and recursively call this function.
+         STEP_SIZE = self.get_preferred_number_of_rows_for_insert()
+         if len(insert_into.values) > STEP_SIZE:
+             final_insert_sql = ""
+             for i in range(0, len(insert_into.values), STEP_SIZE):
+                 temp_insert_into = INSERT_INTO(
+                     fully_qualified_table_name=insert_into.fully_qualified_table_name,
+                     columns=insert_into.columns,
+                     values=insert_into.values[i : i + STEP_SIZE],
+                 )
+                 final_insert_sql += self.build_insert_into_sql(
+                     temp_insert_into, add_semicolon=True
+                 )  # Now we force the semicolon to separate the statements
+                 final_insert_sql += "\n"
+             return final_insert_sql
+
+         return super().build_insert_into_sql(insert_into, add_semicolon=add_semicolon)
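+
+     # Illustrative note (editorial): with the 1000-row cap, an INSERT_INTO carrying 2500
+     # value tuples is emitted as three consecutive statements of 1000, 1000 and 500 rows,
+     # each terminated with a semicolon so they can run as a single batch.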
+
+     def build_insert_into_via_select_sql(
+         self, insert_into_via_select: INSERT_INTO_VIA_SELECT, add_semicolon: bool = True
+     ) -> str:
+         # First get all the WITH clauses from the select elements.
+         with_clauses: list[str] = []
+         remaining_select_elements: list[str] = []
+         for select_element in insert_into_via_select.select_elements:
+             if isinstance(select_element, WITH):
+                 with_clauses.append(select_element)
+             else:  # Split off the other elements
+                 remaining_select_elements.append(select_element)
+         # Then build the WITH statements.
+         with_statements: str = "\n".join(self._build_cte_sql_lines(with_clauses))
+         insert_into_sql: str = f"{with_statements}\nINSERT INTO {insert_into_via_select.fully_qualified_table_name}\n"
+         insert_into_sql += self._build_insert_into_columns_sql(insert_into_via_select) + "\n"
+         insert_into_sql += "(\n" + self.build_select_sql(remaining_select_elements, add_semicolon=False) + "\n)"
+         return insert_into_sql + (";" if add_semicolon else "")
+
+     def get_preferred_number_of_rows_for_insert(self) -> int:
+         return 1000
+
+     def map_test_sql_data_type_to_data_source(self, source_data_type: SqlDataType) -> SqlDataType:
+         """SQL Server always requires a varchar length in CREATE TABLE statements."""
+         sql_data_type = super().map_test_sql_data_type_to_data_source(source_data_type)
+         if sql_data_type.name == "varchar" and sql_data_type.character_maximum_length is None:
+             sql_data_type.character_maximum_length = self.default_varchar_length()
+         return sql_data_type
+
+     @classmethod
+     def is_same_soda_data_type_with_synonyms(cls, expected: SodaDataTypeName, actual: SodaDataTypeName) -> bool:
+         if expected == SodaDataTypeName.CHAR and actual == SodaDataTypeName.VARCHAR:
+             logger.debug(
+                 f"In is_same_soda_data_type_with_synonyms, expected {expected} and actual {actual} are treated as the same because the SQL Server cursor does not distinguish between varchar and char"
+             )
+             return True
+         elif expected == SodaDataTypeName.NUMERIC and actual == SodaDataTypeName.DECIMAL:
+             logger.debug(
+                 f"In is_same_soda_data_type_with_synonyms, expected {expected} and actual {actual} are treated as the same because the SQL Server cursor does not distinguish between numeric and decimal"
+             )
+             return True
+         elif expected == SodaDataTypeName.TIMESTAMP_TZ and actual == SodaDataTypeName.VARCHAR:
+             logger.debug(
+                 f"In is_same_soda_data_type_with_synonyms, expected {expected} and actual {actual} are treated as the same because the SQL Server cursor returns varchar for timestamps with a time zone"
+             )
+             return True
+         return super().is_same_soda_data_type_with_synonyms(expected, actual)
+
+     def _build_string_hash_sql(self, string_hash: STRING_HASH) -> str:
+         return f"CONVERT(VARCHAR(32), HASHBYTES('MD5', {self.build_expression_sql(string_hash.expression)}), 2)"
+
+     def _get_add_column_sql_expr(self) -> str:
+         return "ADD"
+
+     def build_create_table_as_select_sql(
+         self, create_table_as_select: CREATE_TABLE_AS_SELECT, add_semicolon: bool = True, add_parenthesis: bool = True
+     ) -> str:
+         # Copy the select elements and insert an INTO with the same table name as the create table as select statement
+         select_elements = create_table_as_select.select_elements.copy()
+         select_elements += [INTO(fully_qualified_table_name=create_table_as_select.fully_qualified_table_name)]
+         result_sql: str = self.build_select_sql(select_elements, add_semicolon=add_semicolon)
+         return result_sql
+
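+     # Illustrative note (editorial): T-SQL has no CREATE TABLE ... AS SELECT; a CTAS
+     # request like CREATE TABLE [s].[copy] AS SELECT * FROM [s].[orders] is therefore
+     # rendered as the equivalent SELECT * INTO [s].[copy] FROM [s].[orders];
+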
+     def build_drop_view_sql(self, drop_view: DROP_VIEW | DROP_VIEW_IF_EXISTS, add_semicolon: bool = True) -> str:
+         # SQL Server does not allow the database name to be specified in the view name, so we need to strip it.
+         drop_view_copy = deepcopy(drop_view)  # Copy the object so we don't modify the original object
+         # Drop the first prefix (database name) from the fully qualified view name
+         drop_view_copy.fully_qualified_view_name = ".".join(drop_view_copy.fully_qualified_view_name.split(".")[1:])
+         return super().build_drop_view_sql(drop_view_copy, add_semicolon)
+
+     def build_create_view_sql(
+         self, create_view: CREATE_VIEW, add_semicolon: bool = True, add_parenthesis: bool = True
+     ) -> str:
+         # SQL Server does not allow the database name to be specified in the view name, so we need to strip it.
+         create_view_copy = deepcopy(create_view)  # Copy the object so we don't modify the original object
+         # Drop the first prefix (database name) from the fully qualified view name
+         create_view_copy.fully_qualified_view_name = ".".join(create_view_copy.fully_qualified_view_name.split(".")[1:])
+         return super().build_create_view_sql(create_view_copy, add_semicolon, add_parenthesis=False)
@@ -0,0 +1,193 @@
+ from __future__ import annotations
+
+ import logging
+ import struct
+ from abc import ABC
+ from datetime import datetime, timedelta, timezone
+ from typing import Literal, Optional, Union
+
+ import pyodbc
+ from pydantic import Field, SecretStr
+ from soda_core.__version__ import SODA_CORE_VERSION
+ from soda_core.common.data_source_connection import DataSourceConnection
+ from soda_core.common.exceptions import DataSourceConnectionException
+ from soda_core.common.logging_constants import soda_logger
+ from soda_core.model.data_source.data_source import DataSourceBase
+ from soda_core.model.data_source.data_source_connection_properties import (
+     DataSourceConnectionProperties,
+ )
+
+ logger: logging.Logger = soda_logger
+
+
+ CONTEXT_AUTHENTICATION_DESCRIPTION = "Use context authentication"
+ USER_DESCRIPTION = "Username for authentication"
+ DEFAULT_PORT = 1433
+
+
+ class SqlServerConnectionProperties(DataSourceConnectionProperties, ABC):
+     host: str = Field(..., description="Host name of the SQL Server instance")
+     port: int = Field(DEFAULT_PORT, description="Port number of the SQL Server instance")
+     database: str = Field(..., description="Name of the database to use")
+
+     # Optional fields
+     driver: Optional[str] = Field(
+         "ODBC Driver 18 for SQL Server", description="Driver name for the SQL Server instance"
+     )
+     trust_server_certificate: Optional[bool] = Field(False, description="Whether to trust the server certificate")
+     trusted_connection: Optional[bool] = Field(False, description="Whether to use trusted connection")
+     encrypt: Optional[bool] = Field(False, description="Whether to encrypt the connection")
+     connection_max_retries: Optional[int] = Field(0, description="Maximum number of connection retries")
+     enable_tracing: Optional[bool] = Field(False, description="Whether to enable tracing")
+     login_timeout: Optional[int] = Field(0, description="Login timeout")
+     scope: Optional[str] = Field(None, description="Scope for the connection")
+     connection_parameters: Optional[dict[str, str]] = Field(None, description="Connection parameters")
+
+
+ class SqlServerPasswordAuth(SqlServerConnectionProperties):
+     """SQL Server authentication using password"""
+
+     user: str = Field(..., description=USER_DESCRIPTION)
+     password: SecretStr = Field(..., description="Password for authentication")
+     authentication: Literal["sql"] = "sql"
+
+
+ class SqlServerActiveDirectoryAuthentication(SqlServerConnectionProperties):
+     authentication: Literal[
+         "activedirectoryinteractive", "activedirectorypassword", "activedirectoryserviceprincipal"
+     ] = Field(..., description="Authentication type")
+
+
+ class SqlServerActiveDirectoryInteractiveAuthentication(SqlServerActiveDirectoryAuthentication):
+     user: str = Field(..., description=USER_DESCRIPTION)
+     authentication: Literal["activedirectoryinteractive"] = "activedirectoryinteractive"
+
+
+ class SqlServerActiveDirectoryPasswordAuthentication(SqlServerActiveDirectoryAuthentication):
+     authentication: Literal["activedirectorypassword"] = "activedirectorypassword"
+     user: str = Field(..., description=USER_DESCRIPTION)
+     password: SecretStr = Field(..., description="Password for authentication")
+
+
+ class SqlServerActiveDirectoryServicePrincipalAuthentication(SqlServerActiveDirectoryAuthentication):
+     authentication: Literal["activedirectoryserviceprincipal"] = "activedirectoryserviceprincipal"
+     client_id: str = Field(..., description="Client ID for authentication")
+     client_secret: SecretStr = Field(..., description="Client secret for authentication")
+
+
+ class SqlServerDataSource(DataSourceBase, ABC):
+     type: Literal["sqlserver"] = Field("sqlserver")
+
+     connection_properties: Union[
+         SqlServerPasswordAuth,
+         SqlServerActiveDirectoryInteractiveAuthentication,
+         SqlServerActiveDirectoryPasswordAuthentication,
+         SqlServerActiveDirectoryServicePrincipalAuthentication,
+     ] = Field(..., alias="connection", description="SQL Server connection configuration")
+
+
+ def handle_datetime(dto_value):
+     tup = struct.unpack("<6hI2h", dto_value)  # e.g., (2017, 3, 16, 10, 35, 18, 500000000, -6, 0)
+     return datetime(tup[0], tup[1], tup[2], tup[3], tup[4], tup[5], tup[6] // 1000)
+
+
+ def handle_datetimeoffset(dto_value):
+     tup = struct.unpack("<6hI2h", dto_value)  # e.g., (2017, 3, 16, 10, 35, 18, 500000000, -6, 0)
+     return datetime(
+         tup[0],
+         tup[1],
+         tup[2],
+         tup[3],
+         tup[4],
+         tup[5],
+         tup[6] // 1000,
+         timezone(timedelta(hours=tup[7], minutes=tup[8])),
+     )
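+
+ # Illustrative note (editorial): pyodbc hands DATETIMEOFFSET values to output converters
+ # as a packed 20-byte struct; unpacking "<6hI2h" yields, e.g.,
+ #   (2017, 3, 16, 10, 35, 18, 500000000, -6, 0)
+ # i.e. year through second as shorts, nanoseconds as an unsigned int, and the UTC offset
+ # as hour/minute shorts, which handle_datetimeoffset turns into a timezone-aware datetime.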
+
+
+ class SqlServerDataSourceConnection(DataSourceConnection):
+     def __init__(self, name: str, connection_properties: DataSourceConnectionProperties):
+         super().__init__(name, connection_properties)
+
+     def build_connection_string(self, config: SqlServerConnectionProperties):
+         conn_params = []
+
+         conn_params.append(f"DRIVER={{{config.driver}}}")
+         conn_params.append(f"DATABASE={config.database}")
+
+         if "\\" in config.host:
+             # If there is a backslash in the host name, the host is a SQL Server named
+             # instance. In this case the port number has to be omitted.
+             conn_params.append(f"SERVER={config.host}")
+         else:
+             conn_params.append(f"SERVER={config.host},{int(config.port)}")
+
+         if config.trusted_connection:
+             conn_params.append("Trusted_Connection=YES")
+
+         if config.trust_server_certificate:
+             conn_params.append("TrustServerCertificate=YES")
+
+         if config.encrypt:
+             conn_params.append("Encrypt=YES")
+
+         if int(config.connection_max_retries) > 0:
+             conn_params.append(f"ConnectRetryCount={int(config.connection_max_retries)}")
+
+         if config.enable_tracing:
+             conn_params.append("SQL_ATTR_TRACE=SQL_OPT_TRACE_ON")
+
+         if config.authentication.lower() == "sql":
+             conn_params.append(f"UID={{{config.user}}}")
+             conn_params.append(f"PWD={{{config.password.get_secret_value()}}}")
+         elif config.authentication.lower() == "activedirectoryinteractive":
+             conn_params.append("Authentication=ActiveDirectoryInteractive")
+             conn_params.append(f"UID={{{config.user}}}")
+         elif config.authentication.lower() == "activedirectorypassword":
+             conn_params.append("Authentication=ActiveDirectoryPassword")
+             conn_params.append(f"UID={{{config.user}}}")
+             conn_params.append(f"PWD={{{config.password.get_secret_value()}}}")
+         elif config.authentication.lower() == "activedirectoryserviceprincipal":
+             conn_params.append("Authentication=ActiveDirectoryServicePrincipal")
+             conn_params.append(f"UID={{{config.client_id}}}")
+             conn_params.append(f"PWD={{{config.client_secret.get_secret_value()}}}")
+         elif "activedirectory" in config.authentication.lower():
+             conn_params.append(f"Authentication={config.authentication}")
+
+         if config.connection_parameters:
+             for key, value in config.connection_parameters.items():
+                 logger.info(f"Adding connection parameter: {key}={value}")
+                 conn_params.append(f"{key}={value}")
+
+         conn_params.append(f"APP=soda-core-fabric/{SODA_CORE_VERSION}")
+
+         conn_str = ";".join(conn_params)
+
+         return conn_str
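+
+         # Illustrative note (editorial, values assumed): for the default SQL-auth settings
+         # the resulting DSN looks roughly like
+         #   DRIVER={ODBC Driver 18 for SQL Server};DATABASE=master;SERVER=localhost,1433;UID={SA};PWD={...};APP=soda-core-fabric/4.0.5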
+
+     def _get_pyodbc_attrs(self) -> dict[int, bytes] | None:
+         return None
+
+     def _create_connection(
+         self,
+         config: SqlServerConnectionProperties,
+     ):
+         try:
+             self.connection = pyodbc.connect(
+                 self.build_connection_string(config),
+                 attrs_before=self._get_pyodbc_attrs(),
+                 timeout=int(config.login_timeout),
+                 autocommit=self._get_autocommit_setting(),
+             )
+
+             self.connection.add_output_converter(-155, handle_datetimeoffset)
+             self.connection.add_output_converter(-150, handle_datetime)
+             return self.connection
+         except Exception as e:
+             raise DataSourceConnectionException(e) from e
+
+     def _execute_query_get_result_row_column_name(self, column) -> str:
+         return column[0]
+
+     def _get_autocommit_setting(self) -> bool:
+         return False  # No need to set autocommit, as it is set to False by default.
@@ -0,0 +1,65 @@
+ from __future__ import annotations
+
+ import os
+ from typing import Optional
+
+ from helpers.data_source_test_helper import DataSourceTestHelper
+ from soda_core.common.sql_ast import DROP_TABLE, DROP_VIEW
+ from soda_sqlserver.common.data_sources.sqlserver_data_source import (
+     SqlServerDataSourceImpl,
+     SqlServerSqlDialect,
+ )
+
+
+ class SqlServerDataSourceTestHelper(DataSourceTestHelper):
+     def _create_database_name(self) -> Optional[str]:
+         return os.getenv("SQLSERVER_DATABASE", "master")
+
+     def _create_data_source_yaml_str(self) -> str:
+         """
+         Called in _create_data_source_impl to initialize self.data_source_impl.
+         self.database_name and self.schema_name are available if appropriate for the data source type.
+         """
+         return f"""
+             type: sqlserver
+             name: {self.name}
+             connection:
+                 host: '{os.getenv("SQLSERVER_HOST", "localhost")}'
+                 port: '{os.getenv("SQLSERVER_PORT", "1433")}'
+                 database: '{os.getenv("SQLSERVER_DATABASE", "master")}'
+                 user: '{os.getenv("SQLSERVER_USERNAME", "SA")}'
+                 password: '{os.getenv("SQLSERVER_PASSWORD", "Password1!")}'
+                 trust_server_certificate: true
+                 driver: '{os.getenv("SQLSERVER_DRIVER", "ODBC Driver 18 for SQL Server")}'
+         """
+
+     def drop_test_schema_if_exists(self) -> None:
+         """We override this function because the old query in soda-library is a bit unreadable and does not work with Synapse.
+         The logic is the same: drop all tables, then drop the schema if it exists.
+         This is a more "manual" approach, but it is more readable and works with Synapse."""
+         # First find all the tables in the schema. Note: the query helpers return objects
+         # carrying database_name/schema_name/table_name attributes, not plain strings.
+         table_names = self.query_existing_test_tables()
+         data_source_impl: SqlServerDataSourceImpl = self.data_source_impl
+         dialect: SqlServerSqlDialect = data_source_impl.sql_dialect
+         for fully_qualified_table_name in table_names:
+             table_identifier = f"{dialect.quote_default(fully_qualified_table_name.database_name)}.{dialect.quote_default(fully_qualified_table_name.schema_name)}.{dialect.quote_default(fully_qualified_table_name.table_name)}"
+             drop_table_sql = dialect.build_drop_table_sql(DROP_TABLE(table_identifier))
+             self.data_source_impl.execute_update(drop_table_sql)
+
+         view_names = self.query_existing_test_views()
+         for fully_qualified_view_name in view_names:
+             view_identifier = f"{dialect.quote_default(fully_qualified_view_name.database_name)}.{dialect.quote_default(fully_qualified_view_name.schema_name)}.{dialect.quote_default(fully_qualified_view_name.view_name)}"
+             drop_view_sql = dialect.build_drop_view_sql(DROP_VIEW(view_identifier))
+             self.data_source_impl.execute_update(drop_view_sql)
+
+         # Drop the schema if it exists.
+         schema_name = self.extract_schema_from_prefix()
+         if self._does_schema_exist(schema_name):
+             self.data_source_impl.execute_update(f"DROP SCHEMA {dialect.quote_default(schema_name)};")
+
+     def _does_schema_exist(self, schema_name: str) -> bool:
+         """Check if the schema exists in the database."""
+         query_result = self.data_source_impl.execute_query(
+             f"SELECT name FROM sys.schemas WHERE name = '{schema_name}';"
+         )
+         return len(query_result.rows) > 0
@@ -0,0 +1,6 @@
+ Metadata-Version: 2.4
+ Name: soda-sqlserver
+ Version: 4.0.5
+ Requires-Dist: soda-core==4.0.5
+ Requires-Dist: pyodbc
+ Dynamic: requires-dist
@@ -0,0 +1,10 @@
+ setup.py
+ src/soda_sqlserver.egg-info/PKG-INFO
+ src/soda_sqlserver.egg-info/SOURCES.txt
+ src/soda_sqlserver.egg-info/dependency_links.txt
+ src/soda_sqlserver.egg-info/entry_points.txt
+ src/soda_sqlserver.egg-info/requires.txt
+ src/soda_sqlserver.egg-info/top_level.txt
+ src/soda_sqlserver/common/data_sources/sqlserver_data_source.py
+ src/soda_sqlserver/common/data_sources/sqlserver_data_source_connection.py
+ src/soda_sqlserver/test_helpers/sqlserver_data_source_test_helper.py
@@ -0,0 +1,2 @@
+ [soda.plugins.data_source.sqlserver]
+ SqlServerDataSourceImpl = soda_sqlserver.common.data_sources.sqlserver_data_source:SqlServerDataSourceImpl
@@ -0,0 +1,2 @@
+ soda-core==4.0.5
+ pyodbc
@@ -0,0 +1 @@
+ soda_sqlserver