snowflake-sqlalchemy 1.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/sqlalchemy/__init__.py +116 -0
- snowflake/sqlalchemy/_constants.py +12 -0
- snowflake/sqlalchemy/base.py +1065 -0
- snowflake/sqlalchemy/custom_commands.py +621 -0
- snowflake/sqlalchemy/custom_types.py +105 -0
- snowflake/sqlalchemy/provision.py +12 -0
- snowflake/sqlalchemy/requirements.py +297 -0
- snowflake/sqlalchemy/snowdialect.py +911 -0
- snowflake/sqlalchemy/util.py +336 -0
- snowflake/sqlalchemy/version.py +6 -0
- snowflake_sqlalchemy-1.5.2.dist-info/METADATA +503 -0
- snowflake_sqlalchemy-1.5.2.dist-info/RECORD +15 -0
- snowflake_sqlalchemy-1.5.2.dist-info/WHEEL +4 -0
- snowflake_sqlalchemy-1.5.2.dist-info/entry_points.txt +2 -0
- snowflake_sqlalchemy-1.5.2.dist-info/licenses/LICENSE.txt +202 -0
|
@@ -0,0 +1,911 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
|
|
5
|
+
import operator
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
from functools import reduce
|
|
8
|
+
from urllib.parse import unquote_plus
|
|
9
|
+
|
|
10
|
+
import sqlalchemy.types as sqltypes
|
|
11
|
+
from sqlalchemy import event as sa_vnt
|
|
12
|
+
from sqlalchemy import exc as sa_exc
|
|
13
|
+
from sqlalchemy import util as sa_util
|
|
14
|
+
from sqlalchemy.engine import URL, default, reflection
|
|
15
|
+
from sqlalchemy.schema import Table
|
|
16
|
+
from sqlalchemy.sql import text
|
|
17
|
+
from sqlalchemy.sql.elements import quoted_name
|
|
18
|
+
from sqlalchemy.sql.sqltypes import String
|
|
19
|
+
from sqlalchemy.types import (
|
|
20
|
+
BIGINT,
|
|
21
|
+
BINARY,
|
|
22
|
+
BOOLEAN,
|
|
23
|
+
CHAR,
|
|
24
|
+
DATE,
|
|
25
|
+
DATETIME,
|
|
26
|
+
DECIMAL,
|
|
27
|
+
FLOAT,
|
|
28
|
+
INTEGER,
|
|
29
|
+
REAL,
|
|
30
|
+
SMALLINT,
|
|
31
|
+
TIME,
|
|
32
|
+
TIMESTAMP,
|
|
33
|
+
VARCHAR,
|
|
34
|
+
Date,
|
|
35
|
+
DateTime,
|
|
36
|
+
Float,
|
|
37
|
+
Time,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
from snowflake.connector import errors as sf_errors
|
|
41
|
+
from snowflake.connector.connection import DEFAULT_CONFIGURATION
|
|
42
|
+
from snowflake.connector.constants import UTF8
|
|
43
|
+
|
|
44
|
+
from .base import (
|
|
45
|
+
SnowflakeCompiler,
|
|
46
|
+
SnowflakeDDLCompiler,
|
|
47
|
+
SnowflakeExecutionContext,
|
|
48
|
+
SnowflakeIdentifierPreparer,
|
|
49
|
+
SnowflakeTypeCompiler,
|
|
50
|
+
)
|
|
51
|
+
from .custom_types import (
|
|
52
|
+
_CUSTOM_DECIMAL,
|
|
53
|
+
ARRAY,
|
|
54
|
+
GEOGRAPHY,
|
|
55
|
+
GEOMETRY,
|
|
56
|
+
OBJECT,
|
|
57
|
+
TIMESTAMP_LTZ,
|
|
58
|
+
TIMESTAMP_NTZ,
|
|
59
|
+
TIMESTAMP_TZ,
|
|
60
|
+
VARIANT,
|
|
61
|
+
_CUSTOM_Date,
|
|
62
|
+
_CUSTOM_DateTime,
|
|
63
|
+
_CUSTOM_Float,
|
|
64
|
+
_CUSTOM_Time,
|
|
65
|
+
)
|
|
66
|
+
from .util import _update_connection_application_name, parse_url_boolean
|
|
67
|
+
|
|
68
|
+
# Generic SQLAlchemy types mapped to this dialect's implementations; SQLAlchemy
# consults ``colspecs`` when adapting a generic type for the Snowflake dialect.
colspecs = {
    Date: _CUSTOM_Date,
    DateTime: _CUSTOM_DateTime,
    Time: _CUSTOM_Time,
    Float: _CUSTOM_Float,
}

# Snowflake ``data_type`` names (as read from ``information_schema.columns``
# during column reflection) mapped to SQLAlchemy type classes. Lookup is an
# exact match on the name string; unknown names reflect as ``NULLTYPE``.
ischema_names = {
    "BIGINT": BIGINT,
    "BINARY": BINARY,
    "BOOLEAN": BOOLEAN,
    "CHAR": CHAR,
    "CHARACTER": CHAR,
    "DATE": DATE,
    "DATETIME": DATETIME,
    "DEC": DECIMAL,
    "DECIMAL": DECIMAL,
    "DOUBLE": FLOAT,
    "FIXED": DECIMAL,
    "FLOAT": FLOAT,
    "INT": INTEGER,
    "INTEGER": INTEGER,
    "NUMBER": _CUSTOM_DECIMAL,
    "REAL": REAL,
    "BYTEINT": SMALLINT,
    "SMALLINT": SMALLINT,
    "STRING": VARCHAR,
    "TEXT": VARCHAR,
    "TIME": TIME,
    "TIMESTAMP": TIMESTAMP,
    "TIMESTAMP_TZ": TIMESTAMP_TZ,
    "TIMESTAMP_LTZ": TIMESTAMP_LTZ,
    "TIMESTAMP_NTZ": TIMESTAMP_NTZ,
    "TINYINT": SMALLINT,
    "VARBINARY": BINARY,
    "VARCHAR": VARCHAR,
    "VARIANT": VARIANT,
    "OBJECT": OBJECT,
    "ARRAY": ARRAY,
    "GEOGRAPHY": GEOGRAPHY,
    "GEOMETRY": GEOMETRY,
}

# When True, ``SnowflakeDialect.connect`` passes the connection keyword
# parameters through ``_update_connection_application_name`` before connecting.
_ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class SnowflakeDialect(default.DefaultDialect):
    """SQLAlchemy dialect for Snowflake, built on ``snowflake.connector``.

    Reflection is implemented with ``SHOW ...`` commands and queries against
    ``information_schema.columns``. Per-schema results (primary keys, unique
    keys, foreign keys, columns) are fetched by ``reflection.cache``-decorated
    helpers, and the per-table ``get_*`` methods pick their table out of them.
    """

    name = "snowflake"
    driver = "snowflake"
    max_identifier_length = 255
    cte_follows_insert = True

    # TODO: support SQL caching, for more info see: https://docs.sqlalchemy.org/en/14/core/connections.html#caching-for-third-party-dialects
    supports_statement_cache = False

    encoding = UTF8
    default_paramstyle = "pyformat"
    colspecs = colspecs
    ischema_names = ischema_names

    # all str types must be converted in Unicode
    convert_unicode = True

    # Indicate whether the DB-API can receive SQL statements as Python
    # unicode strings
    supports_unicode_statements = True
    supports_unicode_binds = True
    returns_unicode_strings = String.RETURNS_UNICODE
    description_encoding = None

    # No lastrowid support. See SNOW-11155
    postfetch_lastrowid = False

    # Indicate whether the dialect properly implements rowcount for
    # ``UPDATE`` and ``DELETE`` statements.
    supports_sane_rowcount = True

    # Indicate whether the dialect properly implements rowcount for
    # ``UPDATE`` and ``DELETE`` statements when executed via
    # executemany.
    supports_sane_multi_rowcount = True

    # NUMERIC type returns decimal.Decimal
    supports_native_decimal = True

    # The dialect supports a native boolean construct.
    # This will prevent types.Boolean from generating a CHECK
    # constraint when that type is used.
    supports_native_boolean = True

    # The dialect supports ``ALTER TABLE``.
    supports_alter = True

    # The dialect supports CREATE SEQUENCE or similar.
    supports_sequences = True

    # The dialect supports a native ENUM construct.
    supports_native_enum = False

    # The dialect supports inserting multiple rows at once.
    supports_multivalues_insert = True

    # The dialect supports comments
    supports_comments = True

    preparer = SnowflakeIdentifierPreparer
    ddl_compiler = SnowflakeDDLCompiler
    type_compiler = SnowflakeTypeCompiler
    statement_compiler = SnowflakeCompiler
    execution_ctx_cls = SnowflakeExecutionContext

    # indicates symbol names are UPPERCASEd if they are case insensitive
    # within the database. If this is True, the methods normalize_name()
    # and denormalize_name() must be provided.
    requires_name_normalize = True

    multivalues_inserts = True

    supports_schemas = True

    sequences_optional = True

    supports_is_distinct_from = True

    supports_identity_columns = True

    @classmethod
    def dbapi(cls):
        """Return the DB-API module (``snowflake.connector``), imported lazily."""
        from snowflake import connector

        return connector

    def create_connect_args(self, url: URL):
        """Translate a SQLAlchemy URL into ``snowflake.connector`` connect kwargs.

        The URL database component may be ``<database>`` or
        ``<database>/<schema>``; in the two-part form each part is
        URL-unquoted. When the host is not a ``*.snowflakecomputing.com``
        name and no port is given, the host is treated as the account
        identifier and expanded to the full host name on port 443.

        Query parameters are forwarded to the connector. Values whose
        expected type in the connector's ``DEFAULT_CONFIGURATION`` includes
        ``bool`` are parsed with ``parse_url_boolean``; all other values are
        passed through unchanged. ``cache_column_metadata`` is consumed here
        and stored on ``self._cache_column_metadata``.

        Returns:
            ``([], opts)``: no positional args, keyword args for connect.

        Raises:
            sqlalchemy.exc.ArgumentError: if the database component has more
                than two ``/``-separated parts.
        """
        opts = url.translate_connect_args(username="user")
        if "database" in opts:
            name_spaces = [unquote_plus(e) for e in opts["database"].split("/")]
            if len(name_spaces) == 2:
                opts["database"], opts["schema"] = name_spaces
            elif len(name_spaces) != 1:
                raise sa_exc.ArgumentError(
                    f"Invalid name space is specified: {opts['database']}"
                )
        if (
            "host" in opts
            and ".snowflakecomputing.com" not in opts["host"]
            and not opts.get("port")
        ):
            opts["account"] = opts["host"]
            if "." in opts["account"]:
                # remove region subdomain
                opts["account"] = opts["account"][0 : opts["account"].find(".")]
                # remove external ID
                opts["account"] = opts["account"].split("-")[0]
            opts["host"] = opts["host"] + ".snowflakecomputing.com"
            opts["port"] = "443"
        opts["autocommit"] = False  # autocommit is disabled by default

        query = dict(**url.query)  # make mutable
        cache_column_metadata = query.pop("cache_column_metadata", None)
        self._cache_column_metadata = (
            parse_url_boolean(cache_column_metadata) if cache_column_metadata else False
        )

        # URL sets the query parameter values as strings, we need to cast to
        # expected types when necessary
        for name, value in query.items():
            maybe_type_configuration = DEFAULT_CONFIGURATION.get(name)
            if not maybe_type_configuration:
                # parameter not in the type mapping: pass it through as a string
                opts[name] = value
                continue

            _, expected_type = maybe_type_configuration
            if not isinstance(expected_type, tuple):
                expected_type = (expected_type,)

            if not isinstance(value, expected_type) and bool in expected_type:
                # the connector expects a bool: parse the URL string
                opts[name] = parse_url_boolean(value)
            else:
                # Either the value already has an accepted type (e.g. str) or
                # it is a type we do not convert yet.
                # TODO: other types like int are still passed through as string
                # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/447
                opts[name] = value

        return ([], opts)

    def has_table(self, connection, table_name, schema=None, **kw):
        """Return True if the table exists (extra reflection kwargs are ignored)."""
        return self._has_object(connection, "TABLE", table_name, schema)

    def has_sequence(self, connection, sequence_name, schema=None, **kw):
        """Return True if the sequence exists (extra reflection kwargs are ignored)."""
        return self._has_object(connection, "SEQUENCE", sequence_name, schema)

    def _has_object(self, connection, object_type, object_name, schema=None):
        """Return True if ``DESC <object_type> <name>`` returns a row.

        A DB-API error whose original exception is exactly the connector's
        ``ProgrammingError`` is taken to mean the object does not exist; any
        other DB-API error is re-raised.
        """
        full_name = self._denormalize_quote_join(schema, object_name)
        try:
            results = connection.execute(
                text(f"DESC {object_type} /* sqlalchemy:_has_object */ {full_name}")
            )
            return results.fetchone() is not None
        except sa_exc.DBAPIError as e:
            if type(e.orig) is sf_errors.ProgrammingError:
                return False
            raise

    def normalize_name(self, name):
        """Convert a name reported by Snowflake to SQLAlchemy's case convention.

        * ``None`` and ``""`` are returned unchanged.
        * An upper-case name whose lower-case form needs no quoting is
          returned lower-cased.
        * Otherwise, a name unchanged by lower-casing is wrapped in
          ``quoted_name(..., quote=True)`` so its case is preserved.
        * Any other (mixed-case) name is returned unchanged.
        """
        if name is None:
            return None
        if name == "":
            return ""
        if name.upper() == name and not self.identifier_preparer._requires_quotes(
            name.lower()
        ):
            return name.lower()
        elif name.lower() == name:
            return quoted_name(name, quote=True)
        else:
            return name

    def denormalize_name(self, name):
        """Inverse of ``normalize_name``: upper-case plain lower-case names.

        A lower-case name that needs no quoting is upper-cased; every other
        value (including ``None`` and ``""``) is returned unchanged.
        """
        if name is None:
            return None
        if name == "":
            return ""
        elif name.lower() == name and not self.identifier_preparer._requires_quotes(
            name.lower()
        ):
            name = name.upper()
        return name

    def _denormalize_quote_join(self, *idents):
        """Split each non-None identifier on dots, quote the parts, join with ``.``."""
        ip = self.identifier_preparer
        split_idents = reduce(
            operator.add,
            [ip._split_schema_by_dot(ids) for ids in idents if ids is not None],
        )
        return ".".join(ip._quote_free_identifiers(*split_idents))

    @staticmethod
    def _escape_single_quotes(value):
        """Double single quotes so ``value`` can be embedded in a '...' literal."""
        return str(value).replace("'", "''")

    @reflection.cache
    def _current_database_schema(self, connection, **kw):
        """Return ``(current_database, current_schema)`` as normalized names."""
        res = connection.exec_driver_sql(
            "select current_database(), current_schema();"
        ).fetchone()
        return (
            self.normalize_name(res[0]),
            self.normalize_name(res[1]),
        )

    def _get_default_schema_name(self, connection):
        """Return the session's current schema (used as ``default_schema_name``)."""
        # NOTE: no cache object is passed here
        _, current_schema = self._current_database_schema(connection)
        return current_schema

    def _full_schema_name(self, connection, schema=None, **kw):
        """Return the ``database.schema`` name passed to the per-schema SHOW helpers.

        ``schema`` falls back to ``default_schema_name`` and then to the
        session's current schema; the database is always the current one.
        """
        schema = schema or self.default_schema_name
        current_database, current_schema = self._current_database_schema(
            connection, **kw
        )
        return self.denormalize_name(
            self._denormalize_quote_join(current_database, schema or current_schema)
        )

    @staticmethod
    def _map_name_to_idx(result):
        """Map each column name in ``result.cursor.description`` to its index."""
        return {col[0]: idx for idx, col in enumerate(result.cursor.description)}

    @reflection.cache
    def get_indexes(self, connection, table_name, schema=None, **kw):
        """
        Gets all indexes
        """
        # no index is supported by Snowflake
        return []

    @reflection.cache
    def get_check_constraints(self, connection, table_name, schema=None, **kw):
        """Return no check constraints (not supported by Snowflake)."""
        return []

    @reflection.cache
    def _get_schema_primary_keys(self, connection, schema, **kw):
        """Return ``{table_name: {"constrained_columns": [...], "name": ...}}``."""
        result = connection.execute(
            text(
                f"SHOW /* sqlalchemy:_get_schema_primary_keys */PRIMARY KEYS IN SCHEMA {schema}"
            )
        )
        ans = {}
        for row in result:
            table_name = self.normalize_name(row._mapping["table_name"])
            if table_name not in ans:
                ans[table_name] = {
                    "constrained_columns": [],
                    "name": self.normalize_name(row._mapping["constraint_name"]),
                }
            ans[table_name]["constrained_columns"].append(
                self.normalize_name(row._mapping["column_name"])
            )
        return ans

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        """Return the primary-key constraint of one table (empty if none)."""
        full_schema_name = self._full_schema_name(connection, schema, **kw)
        return self._get_schema_primary_keys(connection, full_schema_name, **kw).get(
            table_name, {"constrained_columns": [], "name": None}
        )

    @reflection.cache
    def _get_schema_unique_constraints(self, connection, schema, **kw):
        """Return ``{table_name: [{"column_names": [...], "name": ...}, ...]}``."""
        result = connection.execute(
            text(
                f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN SCHEMA {schema}"
            )
        )
        unique_constraints = {}
        for row in result:
            name = self.normalize_name(row._mapping["constraint_name"])
            if name not in unique_constraints:
                unique_constraints[name] = {
                    "column_names": [self.normalize_name(row._mapping["column_name"])],
                    "name": name,
                    "table_name": self.normalize_name(row._mapping["table_name"]),
                }
            else:
                unique_constraints[name]["column_names"].append(
                    self.normalize_name(row._mapping["column_name"])
                )

        ans = defaultdict(list)
        for constraint in unique_constraints.values():
            table_name = constraint.pop("table_name")
            ans[table_name].append(constraint)
        return ans

    def get_unique_constraints(self, connection, table_name, schema=None, **kw):
        """Return the unique constraints of one table (empty list if none)."""
        full_schema_name = self._full_schema_name(connection, schema, **kw)
        return self._get_schema_unique_constraints(
            connection, full_schema_name, **kw
        ).get(table_name, [])

    @reflection.cache
    def _get_schema_foreign_keys(self, connection, schema, **kw):
        """Return ``{table_name: [foreign_key_dict, ...]}`` for every FK in ``schema``.

        ``referred_schema`` is None when the referred schema is the default
        or current schema. ``options`` holds ``ondelete`` / ``onupdate`` only
        for rules other than ``NO ACTION``.
        """
        _, current_schema = self._current_database_schema(connection, **kw)
        result = connection.execute(
            text(
                f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN SCHEMA {schema}"
            )
        )
        foreign_key_map = {}
        for row in result:
            name = self.normalize_name(row._mapping["fk_name"])
            if name not in foreign_key_map:
                referred_schema = self.normalize_name(row._mapping["pk_schema_name"])
                foreign_key_map[name] = {
                    "constrained_columns": [
                        self.normalize_name(row._mapping["fk_column_name"])
                    ],
                    # referred schema should be None in context where it doesn't need to be specified
                    # https://docs.sqlalchemy.org/en/14/core/reflection.html#reflection-schema-qualified-interaction
                    "referred_schema": (
                        referred_schema
                        if referred_schema
                        not in (self.default_schema_name, current_schema)
                        else None
                    ),
                    "referred_table": self.normalize_name(
                        row._mapping["pk_table_name"]
                    ),
                    "referred_columns": [
                        self.normalize_name(row._mapping["pk_column_name"])
                    ],
                    "name": name,
                    "table_name": self.normalize_name(row._mapping["fk_table_name"]),
                }
                options = {}
                if self.normalize_name(row._mapping["delete_rule"]) != "NO ACTION":
                    options["ondelete"] = self.normalize_name(
                        row._mapping["delete_rule"]
                    )
                if self.normalize_name(row._mapping["update_rule"]) != "NO ACTION":
                    options["onupdate"] = self.normalize_name(
                        row._mapping["update_rule"]
                    )
                foreign_key_map[name]["options"] = options
            else:
                # further rows of a composite key: extend the column lists
                foreign_key_map[name]["constrained_columns"].append(
                    self.normalize_name(row._mapping["fk_column_name"])
                )
                foreign_key_map[name]["referred_columns"].append(
                    self.normalize_name(row._mapping["pk_column_name"])
                )

        ans = {}
        for fk in foreign_key_map.values():
            ans.setdefault(fk["table_name"], []).append(
                {k: v for k, v in fk.items() if k != "table_name"}
            )
        return ans

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        """
        Gets all foreign keys for a table
        """
        full_schema_name = self._full_schema_name(connection, schema, **kw)
        foreign_key_map = self._get_schema_foreign_keys(
            connection, full_schema_name, **kw
        )
        return foreign_key_map.get(table_name, [])

    def _column_type_instance(
        self,
        coltype,
        column_name,
        character_maximum_length,
        numeric_precision,
        numeric_scale,
    ):
        """Return a SQLAlchemy type instance for one information_schema column.

        Unrecognized ``data_type`` names emit a warning and reflect as
        ``sqltypes.NULLTYPE``. That object is already an instance, so it is
        returned as-is; calling it like a class would raise ``TypeError``.
        """
        col_type = self.ischema_names.get(coltype, None)
        if col_type is None:
            sa_util.warn(
                f"Did not recognize type '{coltype}' of column '{column_name}'"
            )
            return sqltypes.NULLTYPE

        col_type_kw = {}
        if issubclass(col_type, FLOAT):
            col_type_kw["precision"] = numeric_precision
            col_type_kw["decimal_return_scale"] = numeric_scale
        elif issubclass(col_type, sqltypes.Numeric):
            col_type_kw["precision"] = numeric_precision
            col_type_kw["scale"] = numeric_scale
        elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)):
            col_type_kw["length"] = character_maximum_length
        return col_type(**col_type_kw)

    def _columns_by_table(self, result, schema_primary_keys):
        """Group information_schema column rows into reflected-column dicts.

        Each row must carry the 12 columns selected by ``_get_schema_columns``
        and ``_get_table_columns``, in that order. Returns
        ``{normalized_table_name: [column_info, ...]}`` in row order.
        ``sys_clustering_column*`` columns are skipped, but their table still
        gets an entry.
        """
        ans = {}
        for (
            table_name,
            column_name,
            coltype,
            character_maximum_length,
            numeric_precision,
            numeric_scale,
            is_nullable,
            column_default,
            is_identity,
            comment,
            identity_start,
            identity_increment,
        ) in result:
            table_name = self.normalize_name(table_name)
            column_name = self.normalize_name(column_name)
            table_columns = ans.setdefault(table_name, [])
            if column_name.startswith("sys_clustering_column"):
                continue  # ignoring clustering column

            type_instance = self._column_type_instance(
                coltype,
                column_name,
                character_maximum_length,
                numeric_precision,
                numeric_scale,
            )
            current_table_pks = schema_primary_keys.get(table_name)
            column_info = {
                "name": column_name,
                "type": type_instance,
                "nullable": is_nullable == "YES",
                "default": column_default,
                "autoincrement": is_identity == "YES",
                "comment": comment if comment != "" else None,
                "primary_key": (
                    column_name in current_table_pks["constrained_columns"]
                    if current_table_pks
                    else False
                ),
            }
            if is_identity == "YES":
                column_info["identity"] = {
                    "start": identity_start,
                    "increment": identity_increment,
                }
            table_columns.append(column_info)
        return ans

    @reflection.cache
    def _get_schema_columns(self, connection, schema, **kw):
        """Return ``{table_name: [column_info, ...]}`` for every table in ``schema``.

        If Snowflake rejects the query with errno 90030 ("Information schema
        query returned too much data"), return None instead of raising. None
        is still cacheable, and ``get_columns`` treats it as the signal to
        fall back to the per-table ``_get_table_columns`` query.
        """
        current_database, _ = self._current_database_schema(connection, **kw)
        full_schema_name = self._denormalize_quote_join(current_database, schema)
        try:
            schema_primary_keys = self._get_schema_primary_keys(
                connection, full_schema_name, **kw
            )
            result = connection.execute(
                text(
                    """
            SELECT /* sqlalchemy:_get_schema_columns */
                   ic.table_name,
                   ic.column_name,
                   ic.data_type,
                   ic.character_maximum_length,
                   ic.numeric_precision,
                   ic.numeric_scale,
                   ic.is_nullable,
                   ic.column_default,
                   ic.is_identity,
                   ic.comment,
                   ic.identity_start,
                   ic.identity_increment
              FROM information_schema.columns ic
             WHERE ic.table_schema=:table_schema
             ORDER BY ic.ordinal_position"""
                ),
                {"table_schema": self.denormalize_name(schema)},
            )
        except sa_exc.ProgrammingError as pe:
            if pe.orig.errno == 90030:
                # This means that there are too many tables in the schema, we need to go more granular
                return None  # None triggers _get_table_columns while staying cacheable
            raise
        return self._columns_by_table(result, schema_primary_keys)

    @reflection.cache
    def _get_table_columns(self, connection, table_name, schema=None, **kw):
        """Return the reflected columns of a single table.

        Raises:
            sqlalchemy.exc.NoSuchTableError: if the query returns no columns.
        """
        current_database, _ = self._current_database_schema(connection, **kw)
        full_schema_name = self._denormalize_quote_join(current_database, schema)
        schema_primary_keys = self._get_schema_primary_keys(
            connection, full_schema_name, **kw
        )
        result = connection.execute(
            text(
                """
        SELECT /* sqlalchemy:get_table_columns */
               ic.table_name,
               ic.column_name,
               ic.data_type,
               ic.character_maximum_length,
               ic.numeric_precision,
               ic.numeric_scale,
               ic.is_nullable,
               ic.column_default,
               ic.is_identity,
               ic.comment,
               ic.identity_start,
               ic.identity_increment
          FROM information_schema.columns ic
         WHERE ic.table_schema=:table_schema
           AND ic.table_name=:table_name
         ORDER BY ic.ordinal_position"""
            ),
            {
                "table_schema": self.denormalize_name(schema),
                "table_name": self.denormalize_name(table_name),
            },
        )
        columns_by_table = self._columns_by_table(result, schema_primary_keys)
        ans = [col for cols in columns_by_table.values() for col in cols]

        # If we didn't find any columns for the table, the table doesn't exist.
        if not ans:
            raise sa_exc.NoSuchTableError(table_name)
        return ans

    def get_columns(self, connection, table_name, schema=None, **kw):
        """Return column info for ``table_name``.

        Uses the cached schema-wide column listing and falls back to a
        per-table query when that listing is too large (see
        ``_get_schema_columns``).

        Raises:
            sqlalchemy.exc.NoSuchTableError: if the table has no columns.
        """
        schema = schema or self.default_schema_name
        if not schema:
            _, schema = self._current_database_schema(connection, **kw)

        schema_columns = self._get_schema_columns(connection, schema, **kw)
        if schema_columns is None:
            # Too many results, fall back to only query about single table
            return self._get_table_columns(connection, table_name, schema, **kw)
        normalized_table_name = self.normalize_name(table_name)
        if normalized_table_name not in schema_columns:
            raise sa_exc.NoSuchTableError(table_name)
        return schema_columns[normalized_table_name]

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        """Return the normalized names of all tables in ``schema`` (or the current one)."""
        schema = schema or self.default_schema_name
        if schema:
            cursor = connection.execute(
                text(
                    f"SHOW /* sqlalchemy:get_table_names */ TABLES IN {self._denormalize_quote_join(schema)}"
                )
            )
        else:
            cursor = connection.execute(
                text("SHOW /* sqlalchemy:get_table_names */ TABLES")
            )
        # column 1 of SHOW TABLES output is read as the table name
        return [self.normalize_name(row[1]) for row in cursor]

    @reflection.cache
    def get_view_names(self, connection, schema=None, **kw):
        """Return the normalized names of all views in ``schema`` (or the current one)."""
        schema = schema or self.default_schema_name
        if schema:
            cursor = connection.execute(
                text(
                    f"SHOW /* sqlalchemy:get_view_names */ VIEWS IN {self._denormalize_quote_join(schema)}"
                )
            )
        else:
            cursor = connection.execute(
                text("SHOW /* sqlalchemy:get_view_names */ VIEWS")
            )

        return [self.normalize_name(row[1]) for row in cursor]

    @reflection.cache
    def get_view_definition(self, connection, view_name, schema=None, **kw):
        """Return the ``text`` column of the first matching view, or None.

        Best-effort: any error while reading the row yields None.
        NOTE(review): the name goes through ``_denormalize_quote_join``, so a
        name that needs quoting ends up double-quoted inside the LIKE
        pattern, which likely prevents a match. Confirm before relying on it.
        """
        schema = schema or self.default_schema_name
        pattern = self._escape_single_quotes(self._denormalize_quote_join(view_name))
        if schema:
            cursor = connection.execute(
                text(
                    f"SHOW /* sqlalchemy:get_view_definition */ VIEWS "
                    f"LIKE '{pattern}' IN {self._denormalize_quote_join(schema)}"
                )
            )
        else:
            cursor = connection.execute(
                text(
                    f"SHOW /* sqlalchemy:get_view_definition */ VIEWS "
                    f"LIKE '{pattern}'"
                )
            )

        n2i = self.__class__._map_name_to_idx(cursor)
        try:
            ret = cursor.fetchone()
            if ret:
                return ret[n2i["text"]]
        except Exception:
            # deliberate best-effort: a missing/odd row means "no definition"
            pass
        return None

    def get_temp_table_names(self, connection, schema=None, **kw):
        """Return the normalized names of tables whose ``kind`` is ``TEMPORARY``."""
        schema = schema or self.default_schema_name
        if schema:
            cursor = connection.execute(
                text(
                    f"SHOW /* sqlalchemy:get_temp_table_names */ TABLES "
                    f"IN {self._denormalize_quote_join(schema)}"
                )
            )
        else:
            cursor = connection.execute(
                text("SHOW /* sqlalchemy:get_temp_table_names */ TABLES")
            )

        n2i = self.__class__._map_name_to_idx(cursor)
        return [
            self.normalize_name(row[n2i["name"]])
            for row in cursor
            if row[n2i["kind"]] == "TEMPORARY"
        ]

    def get_schema_names(self, connection, **kw):
        """Return the normalized names of all schemas visible to the session."""
        cursor = connection.execute(
            text("SHOW /* sqlalchemy:get_schema_names */ SCHEMAS")
        )

        return [self.normalize_name(row[1]) for row in cursor]

    @reflection.cache
    def get_sequence_names(self, connection, schema=None, **kw):
        """Return sequence names in ``schema`` (or the current schema).

        Returns an empty list when Snowflake reports errno 2003 (schema does
        not exist); every other ``ProgrammingError`` propagates.
        """
        sql_command = "SHOW SEQUENCES {}".format(
            f"IN SCHEMA {self.normalize_name(schema)}" if schema else ""
        )
        try:
            cursor = connection.execute(text(sql_command))
            return [self.normalize_name(row[0]) for row in cursor]
        except sa_exc.ProgrammingError as pe:
            if pe.orig.errno == 2003:
                # Schema does not exist
                return []
            raise

    def _get_table_comment(self, connection, table_name, schema=None, **kw):
        """Return the first ``SHOW TABLES LIKE '<table_name>'`` row, or None."""
        sql_command = (
            "SHOW /* sqlalchemy:_get_table_comment */ "
            "TABLES LIKE '{}'{}".format(
                self._escape_single_quotes(table_name),
                f" IN SCHEMA {self.normalize_name(schema)}" if schema else "",
            )
        )
        cursor = connection.execute(text(sql_command))
        return cursor.fetchone()

    def _get_view_comment(self, connection, table_name, schema=None, **kw):
        """Return the first ``SHOW VIEWS LIKE '<table_name>'`` row, or None."""
        sql_command = (
            "SHOW /* sqlalchemy:_get_view_comment */ "
            "VIEWS LIKE '{}'{}".format(
                self._escape_single_quotes(table_name),
                f" IN SCHEMA {self.normalize_name(schema)}" if schema else "",
            )
        )
        cursor = connection.execute(text(sql_command))
        return cursor.fetchone()

    def get_table_comment(self, connection, table_name, schema=None, **kw):
        """
        Returns comment associated with a table (or view) in a dictionary as
        SQLAlchemy expects. Note that since SQLAlchemy may not (in fact,
        typically does not) know if this is a table or a view, we have to
        handle both cases here.
        """
        result = self._get_table_comment(connection, table_name, schema)
        if result is None:
            # the "table" being reflected is actually a view
            result = self._get_view_comment(connection, table_name, schema)

        return {
            "text": (
                result._mapping["comment"]
                if result and result._mapping["comment"]
                else None
            )
        }

    def connect(self, *cargs, **cparams):
        """Open a DB-API connection, optionally tagging the application name.

        When the module flag ``_ENABLE_SQLALCHEMY_AS_APPLICATION_NAME`` is
        true, the keyword parameters are first passed through
        ``_update_connection_application_name``.
        """
        if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME:
            cparams = _update_connection_application_name(**cparams)
        return super().connect(*cargs, **cparams)
|
|
903
|
+
|
|
904
|
+
|
|
905
|
+
@sa_vnt.listens_for(Table, "before_create")
def check_table(table, connection, _ddl_runner, **kw):
    """``before_create`` hook for every ``Table``.

    Refuses to create a table through the Snowflake dialect when the table
    declares any indexes. Tables on other dialects, and tables without
    indexes, pass through untouched.
    """
    on_snowflake = isinstance(_ddl_runner.dialect, SnowflakeDialect)
    if not on_snowflake or not table.indexes:
        return
    raise NotImplementedError("Snowflake does not support indexes")


# Public module-level alias for the dialect class.
dialect = SnowflakeDialect
|