databricks-sqlalchemy 1.0.2__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- CHANGELOG.md +271 -2
- databricks/sqlalchemy/__init__.py +4 -1
- databricks/sqlalchemy/_ddl.py +100 -0
- databricks/sqlalchemy/_parse.py +385 -0
- databricks/sqlalchemy/_types.py +323 -0
- databricks/sqlalchemy/base.py +436 -0
- databricks/sqlalchemy/dependency_test/test_dependency.py +22 -0
- databricks/sqlalchemy/py.typed +0 -0
- databricks/sqlalchemy/pytest.ini +4 -0
- databricks/sqlalchemy/requirements.py +249 -0
- databricks/sqlalchemy/setup.cfg +4 -0
- databricks/sqlalchemy/test/_extra.py +70 -0
- databricks/sqlalchemy/test/_future.py +331 -0
- databricks/sqlalchemy/test/_regression.py +311 -0
- databricks/sqlalchemy/test/_unsupported.py +450 -0
- databricks/sqlalchemy/test/conftest.py +13 -0
- databricks/sqlalchemy/test/overrides/_componentreflectiontest.py +189 -0
- databricks/sqlalchemy/test/overrides/_ctetest.py +33 -0
- databricks/sqlalchemy/test/test_suite.py +13 -0
- databricks/sqlalchemy/test_local/__init__.py +5 -0
- databricks/sqlalchemy/test_local/conftest.py +44 -0
- databricks/sqlalchemy/test_local/e2e/MOCK_DATA.xlsx +0 -0
- databricks/sqlalchemy/test_local/e2e/test_basic.py +543 -0
- databricks/sqlalchemy/test_local/test_ddl.py +96 -0
- databricks/sqlalchemy/test_local/test_parsing.py +160 -0
- databricks/sqlalchemy/test_local/test_types.py +161 -0
- {databricks_sqlalchemy-1.0.2.dist-info → databricks_sqlalchemy-2.0.1.dist-info}/METADATA +60 -39
- databricks_sqlalchemy-2.0.1.dist-info/RECORD +31 -0
- databricks/sqlalchemy/dialect/__init__.py +0 -340
- databricks/sqlalchemy/dialect/base.py +0 -17
- databricks/sqlalchemy/dialect/compiler.py +0 -38
- databricks_sqlalchemy-1.0.2.dist-info/RECORD +0 -10
- {databricks_sqlalchemy-1.0.2.dist-info → databricks_sqlalchemy-2.0.1.dist-info}/LICENSE +0 -0
- {databricks_sqlalchemy-1.0.2.dist-info → databricks_sqlalchemy-2.0.1.dist-info}/WHEEL +0 -0
- {databricks_sqlalchemy-1.0.2.dist-info → databricks_sqlalchemy-2.0.1.dist-info}/entry_points.txt +0 -0
databricks/sqlalchemy/_parse.py
@@ -0,0 +1,385 @@
+from typing import List, Optional, Dict
+import re
+
+import sqlalchemy
+from sqlalchemy.engine import CursorResult
+from sqlalchemy.engine.interfaces import ReflectedColumn
+
+from databricks.sqlalchemy import _types as type_overrides
+
+"""
+This module contains helper functions that can parse the contents
+of metadata and exceptions received from DBR. These are mostly just
+wrappers around regexes.
+"""
+
+
+class DatabricksSqlAlchemyParseException(Exception):
+    pass
+
+
+def _match_table_not_found_string(message: str) -> bool:
+    """Return True if the message contains a substring indicating that a table was not found"""
+
+    DBR_LTE_12_NOT_FOUND_STRING = "Table or view not found"
+    DBR_GT_12_NOT_FOUND_STRING = "TABLE_OR_VIEW_NOT_FOUND"
+    return any(
+        [
+            DBR_LTE_12_NOT_FOUND_STRING in message,
+            DBR_GT_12_NOT_FOUND_STRING in message,
+        ]
+    )
+
+
+def _describe_table_extended_result_to_dict_list(
+    result: CursorResult,
+) -> List[Dict[str, str]]:
+    """Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries"""
+
+    rows_to_return = []
+    for row in result.all():
+        this_row = {"col_name": row.col_name, "data_type": row.data_type}
+        rows_to_return.append(this_row)
+
+    return rows_to_return
+
+
+def extract_identifiers_from_string(input_str: str) -> List[str]:
+    """For a string input resembling (`a`, `b`, `c`) return a list of identifiers ['a', 'b', 'c']"""
+
+    # This matches the valid character list contained in DatabricksIdentifierPreparer
+    pattern = re.compile(r"`([A-Za-z0-9_]+)`")
+    matches = pattern.findall(input_str)
+    return [i for i in matches]
+
+
+def extract_identifier_groups_from_string(input_str: str) -> List[str]:
+    """For a string input resembling :
+
+    FOREIGN KEY (`pname`, `pid`, `pattr`) REFERENCES `main`.`pysql_sqlalchemy`.`tb1` (`name`, `id`, `attr`)
+
+    Return ['(`pname`, `pid`, `pattr`)', '(`name`, `id`, `attr`)']
+    """
+    pattern = re.compile(r"\([`A-Za-z0-9_,\s]*\)")
+    matches = pattern.findall(input_str)
+    return [i for i in matches]
+
+
+def extract_three_level_identifier_from_constraint_string(input_str: str) -> dict:
+    """For a string input resembling :
+    FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`pysql_dialect_compliance`.`users` (`user_id`)
+
+    Return a dict like
+        {
+            "catalog": "main",
+            "schema": "pysql_dialect_compliance",
+            "table": "users"
+        }
+
+    Raise a DatabricksSqlAlchemyParseException if a 3L namespace isn't found
+    """
+    pat = re.compile(r"REFERENCES\s+(.*?)\s*\(")
+    matches = pat.findall(input_str)
+
+    if not matches:
+        raise DatabricksSqlAlchemyParseException(
+            "3L namespace not found in constraint string"
+        )
+
+    first_match = matches[0]
+    parts = first_match.split(".")
+
+    def strip_backticks(input: str):
+        return input.replace("`", "")
+
+    try:
+        return {
+            "catalog": strip_backticks(parts[0]),
+            "schema": strip_backticks(parts[1]),
+            "table": strip_backticks(parts[2]),
+        }
+    except IndexError:
+        raise DatabricksSqlAlchemyParseException(
+            "Incomplete 3L namespace found in constraint string: " + ".".join(parts)
+        )
+
+
+def _parse_fk_from_constraint_string(constraint_str: str) -> dict:
+    """Build a dictionary of foreign key constraint information from a constraint string.
+
+    For example:
+
+    ```
+    FOREIGN KEY (`pname`, `pid`, `pattr`) REFERENCES `main`.`pysql_dialect_compliance`.`tb1` (`name`, `id`, `attr`)
+    ```
+
+    Return a dictionary like:
+
+    ```
+    {
+        "constrained_columns": ["pname", "pid", "pattr"],
+        "referred_table": "tb1",
+        "referred_schema": "pysql_dialect_compliance",
+        "referred_columns": ["name", "id", "attr"]
+    }
+    ```
+
+    Note that the constraint name doesn't appear in the constraint string so it will not
+    be present in the output of this function.
+    """
+
+    referred_table_dict = extract_three_level_identifier_from_constraint_string(
+        constraint_str
+    )
+    referred_table = referred_table_dict["table"]
+    referred_schema = referred_table_dict["schema"]
+
+    # _extracted is a tuple of two lists of identifiers
+    # we assume the first immediately follows "FOREIGN KEY" and the second
+    # immediately follows REFERENCES $tableName
+    _extracted = extract_identifier_groups_from_string(constraint_str)
+    constrained_columns_str, referred_columns_str = (
+        _extracted[0],
+        _extracted[1],
+    )
+
+    constrained_columns = extract_identifiers_from_string(constrained_columns_str)
+    referred_columns = extract_identifiers_from_string(referred_columns_str)
+
+    return {
+        "constrained_columns": constrained_columns,
+        "referred_table": referred_table,
+        "referred_columns": referred_columns,
+        "referred_schema": referred_schema,
+    }
+
+
+def build_fk_dict(
+    fk_name: str, fk_constraint_string: str, schema_name: Optional[str]
+) -> dict:
+    """
+    Given a foreign key name and a foreign key constraint string, return a dictionary
+    with the following keys:
+
+    name
+        the name of the foreign key constraint
+    constrained_columns
+        a list of column names that make up the foreign key
+    referred_table
+        the name of the table that the foreign key references
+    referred_columns
+        a list of column names that are referenced by the foreign key
+    referred_schema
+        the name of the schema that the foreign key references.
+
+    referred_schema will be None if the schema_name argument is None.
+    This is required by SQLAlchemy's ComponentReflectionTest::test_get_foreign_keys
+    """
+
+    # The foreign key name is not contained in the constraint string so we
+    # need to add it manually
+    base_fk_dict = _parse_fk_from_constraint_string(fk_constraint_string)
+
+    if not schema_name:
+        schema_override_dict = dict(referred_schema=None)
+    else:
+        schema_override_dict = {}
+
+    # mypy doesn't like this method of conditionally adding a key to a dictionary
+    # while keeping everything immutable
+    complete_foreign_key_dict = {
+        "name": fk_name,
+        **base_fk_dict,
+        **schema_override_dict,  # type: ignore
+    }
+
+    return complete_foreign_key_dict
+
+
+def _parse_pk_columns_from_constraint_string(constraint_str: str) -> List[str]:
+    """Build a list of constrained columns from a constraint string returned by DESCRIBE TABLE EXTENDED
+
+    For example:
+
+    PRIMARY KEY (`id`, `name`, `email_address`)
+
+    Returns a list like
+
+    ["id", "name", "email_address"]
+    """
+
+    _extracted = extract_identifiers_from_string(constraint_str)
+
+    return _extracted
+
+
+def build_pk_dict(pk_name: str, pk_constraint_string: str) -> dict:
+    """Given a primary key name and a primary key constraint string, return a dictionary
+    with the following keys:
+
+    constrained_columns
+        A list of string column names that make up the primary key
+
+    name
+        The name of the primary key constraint
+    """
+
+    constrained_columns = _parse_pk_columns_from_constraint_string(pk_constraint_string)
+
+    return {"constrained_columns": constrained_columns, "name": pk_name}
+
+
+def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> List[dict]:
+    """Return a list of dictionaries containing only the col_name:data_type pairs where the `data_type`
+    value contains the match argument.
+
+    Today, DESCRIBE TABLE EXTENDED doesn't give a deterministic name to the fields
+    a constraint will be found in its output. So we cycle through its output looking
+    for a match. This is brittle. We could optionally make two roundtrips: the first
+    would query information_schema for the name of the constraint on this table, and
+    a second to DESCRIBE TABLE EXTENDED, at which point we would know the name of the
+    constraint. But for now we instead assume that Python list comprehension is faster
+    than a network roundtrip
+    """
+
+    output_rows = []
+
+    for row_dict in dte_output:
+        if match in row_dict["data_type"]:
+            output_rows.append(row_dict)
+
+    return output_rows
+
+
+def match_dte_rows_by_key(dte_output: List[Dict[str, str]], match: str) -> List[dict]:
+    """Return a list of dictionaries containing only the col_name:data_type pairs where the `col_name`
+    value contains the match argument.
+    """
+
+    output_rows = []
+
+    for row_dict in dte_output:
+        if match in row_dict["col_name"]:
+            output_rows.append(row_dict)
+
+    return output_rows
+
+
+def get_fk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> List[dict]:
+    """If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries,
+    one dictionary per defined constraint
+    """
+
+    output = match_dte_rows_by_value(dte_output, "FOREIGN KEY")
+
+    return output
+
+
+def get_pk_strings_from_dte_output(
+    dte_output: List[Dict[str, str]]
+) -> Optional[List[dict]]:
+    """If the DESCRIBE TABLE EXTENDED output contains primary key constraints, return a list of dictionaries,
+    one dictionary per defined constraint.
+
+    Returns None if no primary key constraints are found.
+    """
+
+    output = match_dte_rows_by_value(dte_output, "PRIMARY KEY")
+
+    return output
+
+
+def get_comment_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[str]:
+    """Returns the value of the first "Comment" col_name data in dte_output"""
+    output = match_dte_rows_by_key(dte_output, "Comment")
+    if not output:
+        return None
+    else:
+        return output[0]["data_type"]
+
+
+# The keys of this dictionary are the values we expect to see in a
+# TGetColumnsRequest's .TYPE_NAME attribute.
+# These are enumerated in ttypes.py as class TTypeId.
+# TODO: confirm that all types in TTypeId are included here.
+GET_COLUMNS_TYPE_MAP = {
+    "boolean": sqlalchemy.types.Boolean,
+    "smallint": sqlalchemy.types.SmallInteger,
+    "tinyint": type_overrides.TINYINT,
+    "int": sqlalchemy.types.Integer,
+    "bigint": sqlalchemy.types.BigInteger,
+    "float": sqlalchemy.types.Float,
+    "double": sqlalchemy.types.Float,
+    "string": sqlalchemy.types.String,
+    "varchar": sqlalchemy.types.String,
+    "char": sqlalchemy.types.String,
+    "binary": sqlalchemy.types.String,
+    "array": sqlalchemy.types.String,
+    "map": sqlalchemy.types.String,
+    "struct": sqlalchemy.types.String,
+    "uniontype": sqlalchemy.types.String,
+    "decimal": sqlalchemy.types.Numeric,
+    "timestamp": type_overrides.TIMESTAMP,
+    "timestamp_ntz": type_overrides.TIMESTAMP_NTZ,
+    "date": sqlalchemy.types.Date,
+}
+
+
+def parse_numeric_type_precision_and_scale(type_name_str):
+    """Return an instantiated sqlalchemy Numeric() type that preserves the precision and scale indicated
+    in the output from TGetColumnsRequest.
+
+    type_name_str
+        The value of TGetColumnsReq.TYPE_NAME.
+
+    If type_name_str is "DECIMAL(18,5)" returns sqlalchemy.types.Numeric(18,5)
+    """
+
+    pattern = re.compile(r"DECIMAL\((\d+,\d+)\)")
+    match = re.search(pattern, type_name_str)
+    precision_and_scale = match.group(1)
+    precision, scale = tuple(precision_and_scale.split(","))
+
+    return sqlalchemy.types.Numeric(int(precision), int(scale))
+
+
+def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColumn:
+    """Returns a dictionary of the ReflectedColumn schema parsed from
+    a single row of the result of a TGetColumnsRequest thrift RPC
+    """
+
+    pat = re.compile(r"^\w+")
+
+    # This method assumes a valid TYPE_NAME field in the response.
+    # TODO: add error handling in case TGetColumnsResponse format changes
+
+    _raw_col_type = re.search(pat, thrift_resp_row.TYPE_NAME).group(0).lower()  # type: ignore
+    _col_type = GET_COLUMNS_TYPE_MAP[_raw_col_type]
+
+    if _raw_col_type == "decimal":
+        final_col_type = parse_numeric_type_precision_and_scale(
+            thrift_resp_row.TYPE_NAME
+        )
+    else:
+        final_col_type = _col_type
+
+    # See comments about autoincrement in test_suite.py
+    # Since Databricks SQL doesn't currently support inline AUTOINCREMENT declarations
+    # the autoincrement must be manually declared with an Identity() construct in SQLAlchemy
+    # Other dialects can perform this extra Identity() step automatically. But that is not
+    # implemented in the Databricks dialect right now. So autoincrement is currently always False.
+    # It's not clear what IS_AUTO_INCREMENT in the thrift response actually reflects or whether
+    # it ever returns a `YES`.
+
+    # Per the guidance in SQLAlchemy's docstrings, we prefer to not even include an autoincrement
+    # key in this dictionary.
+    this_column = {
+        "name": thrift_resp_row.COLUMN_NAME,
+        "type": final_col_type,
+        "nullable": bool(thrift_resp_row.NULLABLE),
+        "default": thrift_resp_row.COLUMN_DEF,
+        "comment": thrift_resp_row.REMARKS or None,
+    }
+
+    # TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects
+    return this_column  # type: ignore
databricks/sqlalchemy/_types.py
@@ -0,0 +1,323 @@
+from datetime import datetime, time, timezone
+from itertools import product
+from typing import Any, Union, Optional
+
+import sqlalchemy
+from sqlalchemy.engine.interfaces import Dialect
+from sqlalchemy.ext.compiler import compiles
+
+from databricks.sql.utils import ParamEscaper
+
+
+def process_literal_param_hack(value: Any):
+    """This method is supposed to accept a Python type and return a string representation of that type.
+    But due to some weirdness in the way SQLAlchemy's literal rendering works, we have to return
+    the value itself because, by the time it reaches our custom type code, it's already been converted
+    into a string.
+
+    TimeTest
+    DateTimeTest
+    DateTimeTZTest
+
+    This dynamic only seems to affect the literal rendering of datetime and time objects.
+
+    All fail without this hack in-place. I'm not sure why. But it works.
+    """
+    return value
+
+
+@compiles(sqlalchemy.types.Enum, "databricks")
+@compiles(sqlalchemy.types.String, "databricks")
+@compiles(sqlalchemy.types.Text, "databricks")
+@compiles(sqlalchemy.types.Time, "databricks")
+@compiles(sqlalchemy.types.Unicode, "databricks")
+@compiles(sqlalchemy.types.UnicodeText, "databricks")
+@compiles(sqlalchemy.types.Uuid, "databricks")
+def compile_string_databricks(type_, compiler, **kw):
+    """
+    We override the default compilation for Enum(), String(), Text(), and Time() because SQLAlchemy
+    defaults to incompatible / abnormal compiled names
+
+      Enum -> VARCHAR
+      String -> VARCHAR[LENGTH]
+      Text -> VARCHAR[LENGTH]
+      Time -> TIME
+      Unicode -> VARCHAR[LENGTH]
+      UnicodeText -> TEXT
+      Uuid -> CHAR[32]
+
+    But all of these types will be compiled to STRING in Databricks SQL
+    """
+    return "STRING"
+
+
+@compiles(sqlalchemy.types.Integer, "databricks")
+def compile_integer_databricks(type_, compiler, **kw):
+    """
+    We need to override the default Integer compilation rendering because Databricks uses "INT" instead of "INTEGER"
+    """
+    return "INT"
+
+
+@compiles(sqlalchemy.types.LargeBinary, "databricks")
+def compile_binary_databricks(type_, compiler, **kw):
+    """
+    We need to override the default LargeBinary compilation rendering because Databricks uses "BINARY" instead of "BLOB"
+    """
+    return "BINARY"
+
+
+@compiles(sqlalchemy.types.Numeric, "databricks")
+def compile_numeric_databricks(type_, compiler, **kw):
+    """
+    We need to override the default Numeric compilation rendering because Databricks uses "DECIMAL" instead of "NUMERIC"
+
+    The built-in visit_DECIMAL behaviour captures the precision and scale. Here we're just mapping calls to compile Numeric
+    to the SQLAlchemy Decimal() implementation
+    """
+    return compiler.visit_DECIMAL(type_, **kw)
+
+
+@compiles(sqlalchemy.types.DateTime, "databricks")
+def compile_datetime_databricks(type_, compiler, **kw):
+    """
+    We need to override the default DateTime compilation rendering because Databricks uses "TIMESTAMP_NTZ" instead of "DATETIME"
+    """
+    return "TIMESTAMP_NTZ"
+
+
+@compiles(sqlalchemy.types.ARRAY, "databricks")
+def compile_array_databricks(type_, compiler, **kw):
+    """
+    SQLAlchemy's default ARRAY can't compile as it's only implemented for Postgresql.
+    The Postgres implementation works for Databricks SQL, so we duplicate that here.
+
+    :type_:
+        This is an instance of sqlalchemy.types.ARRAY which always includes an item_type attribute
+        which is itself an instance of TypeEngine
+
+    https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.ARRAY
+    """
+
+    inner = compiler.process(type_.item_type, **kw)
+
+    return f"ARRAY<{inner}>"
+
+
+class TIMESTAMP_NTZ(sqlalchemy.types.TypeDecorator):
+    """Represents values comprising values of fields year, month, day, hour, minute, and second.
+    All operations are performed without taking any time zone into account.
+
+    Our dialect maps sqlalchemy.types.DateTime() to this type, which means that all DateTime()
+    objects are stored without tzinfo. To read and write timezone-aware datetimes use
+    databricks.sql.TIMESTAMP instead.
+
+    https://docs.databricks.com/en/sql/language-manual/data-types/timestamp-ntz-type.html
+    """
+
+    impl = sqlalchemy.types.DateTime
+
+    cache_ok = True
+
+    def process_result_value(self, value: Union[None, datetime], dialect):
+        if value is None:
+            return None
+        return value.replace(tzinfo=None)
+
+
+class TIMESTAMP(sqlalchemy.types.TypeDecorator):
+    """Represents values comprising values of fields year, month, day, hour, minute, and second,
+    with the session local time-zone.
+
+    Our dialect maps sqlalchemy.types.DateTime() to TIMESTAMP_NTZ, which means that all DateTime()
+    objects are stored without tzinfo. To read and write timezone-aware datetimes use
+    this type instead.
+
+    ```python
+    # This won't work
+    `Column(sqlalchemy.DateTime(timezone=True))`
+
+    # But this does
+    `Column(TIMESTAMP)`
+    ```
+
+    https://docs.databricks.com/en/sql/language-manual/data-types/timestamp-type.html
+    """
+
+    impl = sqlalchemy.types.DateTime
+
+    cache_ok = True
+
+    def process_result_value(self, value: Union[None, datetime], dialect):
+        if value is None:
+            return None
+
+        if not value.tzinfo:
+            return value.replace(tzinfo=timezone.utc)
+        return value
+
+    def process_bind_param(
+        self, value: Union[datetime, None], dialect
+    ) -> Optional[datetime]:
+        """pysql can pass datetime.datetime() objects directly to DBR"""
+        return value
+
+    def process_literal_param(
+        self, value: Union[datetime, None], dialect: Dialect
+    ) -> str:
+        """ """
+        return process_literal_param_hack(value)
+
+
+@compiles(TIMESTAMP, "databricks")
+def compile_timestamp_databricks(type_, compiler, **kw):
+    """
+    We need to override the compilation of this custom TIMESTAMP type because its impl is DateTime,
+    which this dialect otherwise renders as "TIMESTAMP_NTZ"; this type must render as "TIMESTAMP"
+    """
+    return "TIMESTAMP"
+
+
+class DatabricksTimeType(sqlalchemy.types.TypeDecorator):
+    """Databricks has no native TIME type. So we store it as a string."""
+
+    impl = sqlalchemy.types.Time
+    cache_ok = True
+
+    BASE_FMT = "%H:%M:%S"
+    MICROSEC_PART = ".%f"
+    TIMEZONE_PART = "%z"
+
+    def _generate_fmt_string(self, ms: bool, tz: bool) -> str:
+        """Return a format string for datetime.strptime() that includes or excludes microseconds and timezone."""
+        _ = lambda x, y: x if y else ""
+        return f"{self.BASE_FMT}{_(self.MICROSEC_PART,ms)}{_(self.TIMEZONE_PART,tz)}"
+
+    @property
+    def allowed_fmt_strings(self):
+        """Time strings can be read with or without microseconds and with or without a timezone."""
+
+        if not hasattr(self, "_allowed_fmt_strings"):
+            ms_switch = tz_switch = [True, False]
+            self._allowed_fmt_strings = [
+                self._generate_fmt_string(x, y)
+                for x, y in product(ms_switch, tz_switch)
+            ]
+
+        return self._allowed_fmt_strings
+
+    def _parse_result_string(self, value: str) -> time:
+        """Parse a string into a time object. Try all allowed formats until one works."""
+        for fmt in self.allowed_fmt_strings:
+            try:
+                # We use timetz() here because we want to preserve the timezone information
+                # Calling .time() will strip the timezone information
+                return datetime.strptime(value, fmt).timetz()
+            except ValueError:
+                pass
+
+        raise ValueError(f"Could not parse time string {value}")
+
+    def _determine_fmt_string(self, value: time) -> str:
+        """Determine which format string to use to render a time object as a string."""
+        ms_bool = value.microsecond > 0
+        tz_bool = value.tzinfo is not None
+        return self._generate_fmt_string(ms_bool, tz_bool)
+
+    def process_bind_param(self, value: Union[time, None], dialect) -> Union[None, str]:
+        """Values sent to the database are converted to %H:%M:%S strings."""
+        if value is None:
+            return None
+        fmt_string = self._determine_fmt_string(value)
+        return value.strftime(fmt_string)
+
+    # mypy doesn't like this workaround because TypeEngine wants process_literal_param to return a string
+    def process_literal_param(self, value, dialect) -> time:  # type: ignore
+        """ """
+        return process_literal_param_hack(value)
+
+    def process_result_value(
+        self, value: Union[None, str], dialect
+    ) -> Union[time, None]:
+        """Values received from the database are parsed into datetime.time() objects"""
+        if value is None:
+            return None
+
+        return self._parse_result_string(value)
+
+
+class DatabricksStringType(sqlalchemy.types.TypeDecorator):
+    """We have to implement our own String() type because SQLAlchemy's default implementation
+    wants to escape single-quotes with a doubled single-quote. Databricks uses a backslash for
+    escaping of literal strings. And SQLAlchemy's default escaping breaks Databricks SQL.
+    """
+
+    impl = sqlalchemy.types.String
+    cache_ok = True
+    pe = ParamEscaper()
+
+    def process_literal_param(self, value, dialect) -> str:
+        """SQLAlchemy's default string escaping for backslashes doesn't work for databricks. The logic here
+        implements the same logic as our legacy inline escaping logic.
+        """
+
+        return self.pe.escape_string(value)
+
+    def literal_processor(self, dialect):
+        """We manually override this method to prevent further processing of the string literal beyond
+        what happens in the process_literal_param() method.
+
+        The SQLAlchemy docs _specifically_ say to not override this method.
+
+        It appears that any processing that happens from TypeEngine.process_literal_param happens _before_
+        and _in addition to_ whatever the class's impl.literal_processor() method does. The String.literal_processor()
+        method performs a string replacement that doubles any single-quote in the contained string. This raises a syntax
+        error in Databricks. And it's not necessary because ParamEscaper() already implements all the escaping we need.
+
+        We should consider opening an issue on the SQLAlchemy project to see if I'm using it wrong.
+
+        See type_api.py::TypeEngine.literal_processor:
+
+        ```python
+        def process(value: Any) -> str:
+            return fixed_impl_processor(
+                fixed_process_literal_param(value, dialect)
+            )
+        ```
+
+        That call to fixed_impl_processor wraps the result of fixed_process_literal_param (which is the
+        process_literal_param defined in our Databricks dialect)
+
+        https://docs.sqlalchemy.org/en/20/core/custom_types.html#sqlalchemy.types.TypeDecorator.literal_processor
+        """
+
+        def process(value):
+            """This is a copy of the default String.literal_processor() method but stripping away
+            its double-escaping behaviour for single-quotes.
+            """
+
+            _step1 = self.process_literal_param(value, dialect="databricks")
+            if dialect.identifier_preparer._double_percents:
+                _step2 = _step1.replace("%", "%%")
+            else:
+                _step2 = _step1
+
+            return "%s" % _step2
+
+        return process
+
+
+class TINYINT(sqlalchemy.types.TypeDecorator):
+    """Represents 1-byte signed integers
+
+    Acts like a sqlalchemy SmallInteger() in Python but writes to a TINYINT field in Databricks
+
+    https://docs.databricks.com/en/sql/language-manual/data-types/tinyint-type.html
+    """
+
+    impl = sqlalchemy.types.SmallInteger
+    cache_ok = True
+
+
+@compiles(TINYINT, "databricks")
+def compile_tinyint(type_, compiler, **kw):
+    return "TINYINT"