snowflake-sqlalchemy 1.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,911 @@
1
+ #
2
+ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+
5
+ import operator
6
+ from collections import defaultdict
7
+ from functools import reduce
8
+ from urllib.parse import unquote_plus
9
+
10
+ import sqlalchemy.types as sqltypes
11
+ from sqlalchemy import event as sa_vnt
12
+ from sqlalchemy import exc as sa_exc
13
+ from sqlalchemy import util as sa_util
14
+ from sqlalchemy.engine import URL, default, reflection
15
+ from sqlalchemy.schema import Table
16
+ from sqlalchemy.sql import text
17
+ from sqlalchemy.sql.elements import quoted_name
18
+ from sqlalchemy.sql.sqltypes import String
19
+ from sqlalchemy.types import (
20
+ BIGINT,
21
+ BINARY,
22
+ BOOLEAN,
23
+ CHAR,
24
+ DATE,
25
+ DATETIME,
26
+ DECIMAL,
27
+ FLOAT,
28
+ INTEGER,
29
+ REAL,
30
+ SMALLINT,
31
+ TIME,
32
+ TIMESTAMP,
33
+ VARCHAR,
34
+ Date,
35
+ DateTime,
36
+ Float,
37
+ Time,
38
+ )
39
+
40
+ from snowflake.connector import errors as sf_errors
41
+ from snowflake.connector.connection import DEFAULT_CONFIGURATION
42
+ from snowflake.connector.constants import UTF8
43
+
44
+ from .base import (
45
+ SnowflakeCompiler,
46
+ SnowflakeDDLCompiler,
47
+ SnowflakeExecutionContext,
48
+ SnowflakeIdentifierPreparer,
49
+ SnowflakeTypeCompiler,
50
+ )
51
+ from .custom_types import (
52
+ _CUSTOM_DECIMAL,
53
+ ARRAY,
54
+ GEOGRAPHY,
55
+ GEOMETRY,
56
+ OBJECT,
57
+ TIMESTAMP_LTZ,
58
+ TIMESTAMP_NTZ,
59
+ TIMESTAMP_TZ,
60
+ VARIANT,
61
+ _CUSTOM_Date,
62
+ _CUSTOM_DateTime,
63
+ _CUSTOM_Float,
64
+ _CUSTOM_Time,
65
+ )
66
+ from .util import _update_connection_application_name, parse_url_boolean
67
+
68
# Substitutions from generic SQLAlchemy types to the Snowflake-specific
# implementations imported from ``.custom_types``. Exposed on the dialect
# through ``SnowflakeDialect.colspecs``.
colspecs = {
    Date: _CUSTOM_Date,
    DateTime: _CUSTOM_DateTime,
    Time: _CUSTOM_Time,
    Float: _CUSTOM_Float,
}

# Lookup from the type name reported by ``information_schema.columns.data_type``
# to the SQLAlchemy type class used when reflecting a column. Names missing
# from this table are reported with a warning during reflection.
ischema_names = {
    "BIGINT": BIGINT,
    "BINARY": BINARY,
    # 'BIT': BIT,
    "BOOLEAN": BOOLEAN,
    "CHAR": CHAR,
    "CHARACTER": CHAR,
    "DATE": DATE,
    "DATETIME": DATETIME,
    "DEC": DECIMAL,
    "DECIMAL": DECIMAL,
    "DOUBLE": FLOAT,
    "FIXED": DECIMAL,
    "FLOAT": FLOAT,
    "INT": INTEGER,
    "INTEGER": INTEGER,
    "NUMBER": _CUSTOM_DECIMAL,
    # NOTE: "OBJECT" is mapped further down, next to the other semi-structured types.
    "REAL": REAL,
    "BYTEINT": SMALLINT,
    "SMALLINT": SMALLINT,
    "STRING": VARCHAR,
    "TEXT": VARCHAR,
    "TIME": TIME,
    "TIMESTAMP": TIMESTAMP,
    "TIMESTAMP_TZ": TIMESTAMP_TZ,
    "TIMESTAMP_LTZ": TIMESTAMP_LTZ,
    "TIMESTAMP_NTZ": TIMESTAMP_NTZ,
    "TINYINT": SMALLINT,
    "VARBINARY": BINARY,
    "VARCHAR": VARCHAR,
    "VARIANT": VARIANT,
    "OBJECT": OBJECT,
    "ARRAY": ARRAY,
    "GEOGRAPHY": GEOGRAPHY,
    "GEOMETRY": GEOMETRY,
}

# When true, ``SnowflakeDialect.connect`` routes the connection parameters
# through ``_update_connection_application_name`` before connecting.
_ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True
114
+
115
+
116
class SnowflakeDialect(default.DefaultDialect):
    """SQLAlchemy dialect that talks to Snowflake through ``snowflake.connector``.

    Besides the capability flags below, the class implements URL-to-connect
    argument translation, identifier (de)normalization and the reflection
    API (tables, views, columns, constraints, sequences, comments).
    """

    name = "snowflake"
    driver = "snowflake"
    max_identifier_length = 255
    cte_follows_insert = True

    # TODO: support SQL caching, for more info see: https://docs.sqlalchemy.org/en/14/core/connections.html#caching-for-third-party-dialects
    supports_statement_cache = False

    encoding = UTF8
    default_paramstyle = "pyformat"
    colspecs = colspecs
    ischema_names = ischema_names

    # all str types must be converted in Unicode
    convert_unicode = True

    # Indicate whether the DB-API can receive SQL statements as Python
    # unicode strings
    supports_unicode_statements = True
    supports_unicode_binds = True
    returns_unicode_strings = String.RETURNS_UNICODE
    description_encoding = None

    # No lastrowid support. See SNOW-11155
    postfetch_lastrowid = False

    # Indicate whether the dialect properly implements rowcount for
    # ``UPDATE`` and ``DELETE`` statements.
    supports_sane_rowcount = True

    # Indicate whether the dialect properly implements rowcount for
    # ``UPDATE`` and ``DELETE`` statements when executed via
    # executemany.
    supports_sane_multi_rowcount = True

    # NUMERIC type returns decimal.Decimal
    supports_native_decimal = True

    # The dialect supports a native boolean construct.
    # This will prevent types.Boolean from generating a CHECK
    # constraint when that type is used.
    supports_native_boolean = True

    # The dialect supports ``ALTER TABLE``.
    supports_alter = True

    # The dialect supports CREATE SEQUENCE or similar.
    supports_sequences = True

    # The dialect supports a native ENUM construct.
    supports_native_enum = False

    # The dialect supports inserting multiple rows at once.
    supports_multivalues_insert = True

    # The dialect supports comments
    supports_comments = True

    preparer = SnowflakeIdentifierPreparer
    ddl_compiler = SnowflakeDDLCompiler
    type_compiler = SnowflakeTypeCompiler
    statement_compiler = SnowflakeCompiler
    execution_ctx_cls = SnowflakeExecutionContext

    # indicates symbol names are UPPERCASEd if they are case insensitive
    # within the database. If this is True, the methods normalize_name()
    # and denormalize_name() must be provided.
    requires_name_normalize = True

    multivalues_inserts = True

    supports_schemas = True

    sequences_optional = True

    supports_is_distinct_from = True

    supports_identity_columns = True

    @classmethod
    def dbapi(cls):
        """Return the ``snowflake.connector`` module as the DB-API implementation."""
        from snowflake import connector

        return connector

    def create_connect_args(self, url: URL):
        """Translate a SQLAlchemy URL into connector arguments.

        Steps performed:

        * ``database`` may be given as ``db`` or ``db/schema``; the two-part
          form is split (and URL-unquoted) into ``database`` and ``schema``.
        * A host without ``.snowflakecomputing.com`` and without a port is
          treated as an account identifier: ``account`` is derived from it,
          the domain suffix is appended to ``host`` and port 443 is used.
        * ``autocommit`` is forced to ``False``.
        * ``cache_column_metadata`` is consumed from the query string and
          stored on the dialect instead of being forwarded.
        * Remaining query parameters are forwarded; those declared as
          boolean in ``DEFAULT_CONFIGURATION`` are parsed into ``bool``.

        :param url: the engine URL.
        :return: ``([], opts)`` - no positional args, keyword args dict.
        :raises sqlalchemy.exc.ArgumentError: if the database part has more
            than two ``/``-separated components.
        """
        opts = url.translate_connect_args(username="user")
        if "database" in opts:
            name_spaces = [unquote_plus(e) for e in opts["database"].split("/")]
            if len(name_spaces) == 1:
                pass
            elif len(name_spaces) == 2:
                opts["database"] = name_spaces[0]
                opts["schema"] = name_spaces[1]
            else:
                raise sa_exc.ArgumentError(
                    f"Invalid name space is specified: {opts['database']}"
                )
        if (
            "host" in opts
            and ".snowflakecomputing.com" not in opts["host"]
            and not opts.get("port")
        ):
            opts["account"] = opts["host"]
            if "." in opts["account"]:
                # remove region subdomain
                opts["account"] = opts["account"][0 : opts["account"].find(".")]
                # remove external ID
                opts["account"] = opts["account"].split("-")[0]
            opts["host"] = opts["host"] + ".snowflakecomputing.com"
            opts["port"] = "443"
        opts["autocommit"] = False  # autocommit is disabled by default

        query = dict(**url.query)  # make mutable
        cache_column_metadata = query.pop("cache_column_metadata", None)
        self._cache_column_metadata = (
            parse_url_boolean(cache_column_metadata) if cache_column_metadata else False
        )

        # URL sets the query parameter values as strings, we need to cast to expected types when necessary
        for name, value in query.items():
            maybe_type_configuration = DEFAULT_CONFIGURATION.get(name)
            if (
                not maybe_type_configuration
            ):  # if the parameter is not found in the type mapping, pass it through as a string
                opts[name] = value
                continue

            (_, expected_type) = maybe_type_configuration
            if not isinstance(expected_type, tuple):
                expected_type = (expected_type,)

            if isinstance(
                value, expected_type
            ):  # if the expected type is str, pass it through as a string
                opts[name] = value

            elif (
                bool in expected_type
            ):  # if the expected type is bool, parse it and pass as a boolean
                opts[name] = parse_url_boolean(value)
            else:
                # TODO: other types like int are still passed through as string
                # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/447
                opts[name] = value

        return ([], opts)

    def has_table(self, connection, table_name, schema=None):
        """
        Checks if the table exists
        """
        return self._has_object(connection, "TABLE", table_name, schema)

    def has_sequence(self, connection, sequence_name, schema=None):
        """
        Checks if the sequence exists
        """
        return self._has_object(connection, "SEQUENCE", sequence_name, schema)

    def _has_object(self, connection, object_type, object_name, schema=None):
        """Return whether ``DESC <object_type> <name>`` yields at least one row.

        A ``DBAPIError`` whose original exception is exactly the connector's
        ``ProgrammingError`` is reported as ``False``; any other database
        error propagates.
        """
        full_name = self._denormalize_quote_join(schema, object_name)
        try:
            results = connection.execute(
                text(f"DESC {object_type} /* sqlalchemy:_has_object */ {full_name}")
            )
            row = results.fetchone()
            have = row is not None
            return have
        except sa_exc.DBAPIError as e:
            # Exact class comparison is kept on purpose: subclasses of
            # ProgrammingError are re-raised rather than treated as "missing".
            if e.orig.__class__ == sf_errors.ProgrammingError:
                return False
            raise

    def normalize_name(self, name):
        """Convert a database-side identifier to its SQLAlchemy-side form.

        All-uppercase names that need no quoting become lowercase; names
        that are entirely lowercase are wrapped in ``quoted_name`` with
        ``quote=True``; everything else is returned unchanged.
        ``None`` and ``""`` are returned as they are.
        """
        if name is None:
            return None
        if name == "":
            return ""
        if name.upper() == name and not self.identifier_preparer._requires_quotes(
            name.lower()
        ):
            return name.lower()
        elif name.lower() == name:
            return quoted_name(name, quote=True)
        else:
            return name

    def denormalize_name(self, name):
        """Convert a SQLAlchemy-side identifier to its database-side form.

        All-lowercase names that need no quoting are upper-cased; anything
        else (including ``None`` and ``""``) is returned unchanged.
        """
        if name is None:
            return None
        if name == "":
            return ""
        elif name.lower() == name and not self.identifier_preparer._requires_quotes(
            name.lower()
        ):
            name = name.upper()
        return name

    def _denormalize_quote_join(self, *idents):
        """Split each non-``None`` identifier on dots, quote the parts and join with ``.``.

        Splitting and quoting are delegated to the identifier preparer.
        At least one identifier must be non-``None``.
        """
        ip = self.identifier_preparer
        split_idents = reduce(
            operator.add,
            [ip._split_schema_by_dot(ids) for ids in idents if ids is not None],
        )
        return ".".join(ip._quote_free_identifiers(*split_idents))

    @reflection.cache
    def _current_database_schema(self, connection, **kw):
        """Return ``(database, schema)`` of the session, both normalized."""
        res = connection.exec_driver_sql(
            "select current_database(), current_schema();"
        ).fetchone()
        return (
            self.normalize_name(res[0]),
            self.normalize_name(res[1]),
        )

    def _get_default_schema_name(self, connection):
        """Return the session's current schema as the default schema name."""
        # NOTE: no cache object is passed here
        _, current_schema = self._current_database_schema(connection)
        return current_schema

    @staticmethod
    def _map_name_to_idx(result):
        """Map each column name in ``result.cursor.description`` to its position."""
        name_to_idx = {}
        for idx, col in enumerate(result.cursor.description):
            name_to_idx[col[0]] = idx
        return name_to_idx

    @reflection.cache
    def get_indexes(self, connection, table_name, schema=None, **kw):
        """
        Gets all indexes
        """
        # no index is supported by Snowflake
        return []

    @reflection.cache
    def get_check_constraints(self, connection, table_name, schema=None, **kw):
        """Return check constraints for a table (always empty)."""
        # check constraints are not supported by Snowflake
        return []

    @reflection.cache
    def _get_schema_primary_keys(self, connection, schema, **kw):
        """Return primary keys of every table in ``schema``.

        :param schema: schema name as it should appear in the SQL statement
            (the caller is responsible for quoting / qualification).
        :return: dict mapping normalized table name to
            ``{"constrained_columns": [...], "name": <constraint name>}``.
        """
        result = connection.execute(
            text(
                f"SHOW /* sqlalchemy:_get_schema_primary_keys */PRIMARY KEYS IN SCHEMA {schema}"
            )
        )
        ans = {}
        for row in result:
            table_name = self.normalize_name(row._mapping["table_name"])
            if table_name not in ans:
                ans[table_name] = {
                    "constrained_columns": [],
                    "name": self.normalize_name(row._mapping["constraint_name"]),
                }
            ans[table_name]["constrained_columns"].append(
                self.normalize_name(row._mapping["column_name"])
            )
        return ans

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        """Return the primary key constraint of ``table_name``.

        Falls back to ``{"constrained_columns": [], "name": None}`` when the
        table has no primary key in the schema-wide result.
        """
        schema = schema or self.default_schema_name
        current_database, current_schema = self._current_database_schema(
            connection, **kw
        )
        full_schema_name = self._denormalize_quote_join(
            current_database, schema if schema else current_schema
        )
        return self._get_schema_primary_keys(
            connection, self.denormalize_name(full_schema_name), **kw
        ).get(table_name, {"constrained_columns": [], "name": None})

    @reflection.cache
    def _get_schema_unique_constraints(self, connection, schema, **kw):
        """Return unique constraints of every table in ``schema``.

        :return: mapping of normalized table name to a list of
            ``{"column_names": [...], "name": <constraint name>}`` dicts.
        """
        result = connection.execute(
            text(
                f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN SCHEMA {schema}"
            )
        )
        unique_constraints = {}
        for row in result:
            name = self.normalize_name(row._mapping["constraint_name"])
            if name not in unique_constraints:
                unique_constraints[name] = {
                    "column_names": [self.normalize_name(row._mapping["column_name"])],
                    "name": name,
                    "table_name": self.normalize_name(row._mapping["table_name"]),
                }
            else:
                unique_constraints[name]["column_names"].append(
                    self.normalize_name(row._mapping["column_name"])
                )

        # Regroup by table; "table_name" was only needed for the grouping.
        ans = defaultdict(list)
        for constraint in unique_constraints.values():
            table_name = constraint.pop("table_name")
            ans[table_name].append(constraint)
        return ans

    def get_unique_constraints(self, connection, table_name, schema=None, **kw):
        """Return the unique constraints of ``table_name`` (empty list if none)."""
        schema = schema or self.default_schema_name
        current_database, current_schema = self._current_database_schema(
            connection, **kw
        )
        full_schema_name = self._denormalize_quote_join(
            current_database, schema if schema else current_schema
        )
        return self._get_schema_unique_constraints(
            connection, self.denormalize_name(full_schema_name), **kw
        ).get(table_name, [])

    @reflection.cache
    def _get_schema_foreign_keys(self, connection, schema, **kw):
        """Return foreign keys of every table in ``schema``.

        :return: dict mapping normalized table name to a list of foreign
            key descriptions (``constrained_columns``, ``referred_schema``,
            ``referred_table``, ``referred_columns``, ``name``, ``options``).
        """
        _, current_schema = self._current_database_schema(connection, **kw)
        result = connection.execute(
            text(
                f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN SCHEMA {schema}"
            )
        )
        foreign_key_map = {}
        for row in result:
            name = self.normalize_name(row._mapping["fk_name"])
            if name not in foreign_key_map:
                referred_schema = self.normalize_name(row._mapping["pk_schema_name"])
                foreign_key_map[name] = {
                    "constrained_columns": [
                        self.normalize_name(row._mapping["fk_column_name"])
                    ],
                    # referred schema should be None in context where it doesn't need to be specified
                    # https://docs.sqlalchemy.org/en/14/core/reflection.html#reflection-schema-qualified-interaction
                    "referred_schema": (
                        referred_schema
                        if referred_schema
                        not in (self.default_schema_name, current_schema)
                        else None
                    ),
                    "referred_table": self.normalize_name(
                        row._mapping["pk_table_name"]
                    ),
                    "referred_columns": [
                        self.normalize_name(row._mapping["pk_column_name"])
                    ],
                    "name": name,
                    "table_name": self.normalize_name(row._mapping["fk_table_name"]),
                }
                options = {}
                if self.normalize_name(row._mapping["delete_rule"]) != "NO ACTION":
                    options["ondelete"] = self.normalize_name(
                        row._mapping["delete_rule"]
                    )
                if self.normalize_name(row._mapping["update_rule"]) != "NO ACTION":
                    options["onupdate"] = self.normalize_name(
                        row._mapping["update_rule"]
                    )
                foreign_key_map[name]["options"] = options
            else:
                # Further rows of a multi-column key only add column pairs.
                foreign_key_map[name]["constrained_columns"].append(
                    self.normalize_name(row._mapping["fk_column_name"])
                )
                foreign_key_map[name]["referred_columns"].append(
                    self.normalize_name(row._mapping["pk_column_name"])
                )

        ans = {}

        for _, v in foreign_key_map.items():
            if v["table_name"] not in ans:
                ans[v["table_name"]] = []
            ans[v["table_name"]].append(
                {k2: v2 for k2, v2 in v.items() if k2 != "table_name"}
            )
        return ans

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        """
        Gets all foreign keys for a table
        """
        schema = schema or self.default_schema_name
        current_database, current_schema = self._current_database_schema(
            connection, **kw
        )
        full_schema_name = self._denormalize_quote_join(
            current_database, schema if schema else current_schema
        )

        foreign_key_map = self._get_schema_foreign_keys(
            connection, self.denormalize_name(full_schema_name), **kw
        )
        return foreign_key_map.get(table_name, [])

    def _resolve_column_type(
        self,
        coltype,
        column_name,
        character_maximum_length,
        numeric_precision,
        numeric_scale,
    ):
        """Build the SQLAlchemy type instance for one reflected column.

        :param coltype: type name as reported by ``information_schema``.
        :param column_name: used only for the warning message.
        :return: an instance of the class registered in ``ischema_names``,
            or ``sqltypes.NULLTYPE`` (after a warning) for unknown names.
        """
        col_type = self.ischema_names.get(coltype, None)
        if col_type is None:
            sa_util.warn(
                f"Did not recognize type '{coltype}' of column '{column_name}'"
            )
            # NULLTYPE is already an instance; it must be returned as-is
            # because calling it would raise TypeError.
            return sqltypes.NULLTYPE

        col_type_kw = {}
        if issubclass(col_type, FLOAT):
            col_type_kw["precision"] = numeric_precision
            col_type_kw["decimal_return_scale"] = numeric_scale
        elif issubclass(col_type, sqltypes.Numeric):
            col_type_kw["precision"] = numeric_precision
            col_type_kw["scale"] = numeric_scale
        elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)):
            col_type_kw["length"] = character_maximum_length
        return col_type(**col_type_kw)

    @reflection.cache
    def _get_schema_columns(self, connection, schema, **kw):
        """Get all columns in the schema, if we hit 'Information schema query returned too much data' problem return
        None, as it is cacheable and is an unexpected return type for this function"""
        ans = {}
        current_database, _ = self._current_database_schema(connection, **kw)
        full_schema_name = self._denormalize_quote_join(current_database, schema)
        try:
            schema_primary_keys = self._get_schema_primary_keys(
                connection, full_schema_name, **kw
            )
            result = connection.execute(
                text(
                    """
            SELECT /* sqlalchemy:_get_schema_columns */
                   ic.table_name,
                   ic.column_name,
                   ic.data_type,
                   ic.character_maximum_length,
                   ic.numeric_precision,
                   ic.numeric_scale,
                   ic.is_nullable,
                   ic.column_default,
                   ic.is_identity,
                   ic.comment,
                   ic.identity_start,
                   ic.identity_increment
              FROM information_schema.columns ic
             WHERE ic.table_schema=:table_schema
             ORDER BY ic.ordinal_position"""
                ),
                {"table_schema": self.denormalize_name(schema)},
            )
        except sa_exc.ProgrammingError as pe:
            if pe.orig.errno == 90030:
                # This means that there are too many tables in the schema, we need to go more granular
                return None  # None triggers _get_table_columns while staying cacheable
            raise
        for (
            table_name,
            column_name,
            coltype,
            character_maximum_length,
            numeric_precision,
            numeric_scale,
            is_nullable,
            column_default,
            is_identity,
            comment,
            identity_start,
            identity_increment,
        ) in result:
            table_name = self.normalize_name(table_name)
            column_name = self.normalize_name(column_name)
            if table_name not in ans:
                ans[table_name] = list()
            if column_name.startswith("sys_clustering_column"):
                continue  # ignoring clustering column

            type_instance = self._resolve_column_type(
                coltype,
                column_name,
                character_maximum_length,
                numeric_precision,
                numeric_scale,
            )

            current_table_pks = schema_primary_keys.get(table_name)

            ans[table_name].append(
                {
                    "name": column_name,
                    "type": type_instance,
                    "nullable": is_nullable == "YES",
                    "default": column_default,
                    "autoincrement": is_identity == "YES",
                    "comment": comment,
                    "primary_key": (
                        (
                            column_name
                            in schema_primary_keys[table_name]["constrained_columns"]
                        )
                        if current_table_pks
                        else False
                    ),
                }
            )
            if is_identity == "YES":
                ans[table_name][-1]["identity"] = {
                    "start": identity_start,
                    "increment": identity_increment,
                }
        return ans

    @reflection.cache
    def _get_table_columns(self, connection, table_name, schema=None, **kw):
        """Get all columns in a table in a schema

        :raises sqlalchemy.exc.NoSuchTableError: if no column row is found.
        """
        ans = []
        current_database, _ = self._current_database_schema(connection, **kw)
        full_schema_name = self._denormalize_quote_join(current_database, schema)
        schema_primary_keys = self._get_schema_primary_keys(
            connection, full_schema_name, **kw
        )
        result = connection.execute(
            text(
                """
        SELECT /* sqlalchemy:get_table_columns */
               ic.table_name,
               ic.column_name,
               ic.data_type,
               ic.character_maximum_length,
               ic.numeric_precision,
               ic.numeric_scale,
               ic.is_nullable,
               ic.column_default,
               ic.is_identity,
               ic.comment
          FROM information_schema.columns ic
         WHERE ic.table_schema=:table_schema
           AND ic.table_name=:table_name
         ORDER BY ic.ordinal_position"""
            ),
            {
                "table_schema": self.denormalize_name(schema),
                "table_name": self.denormalize_name(table_name),
            },
        )
        for (
            table_name,
            column_name,
            coltype,
            character_maximum_length,
            numeric_precision,
            numeric_scale,
            is_nullable,
            column_default,
            is_identity,
            comment,
        ) in result:
            table_name = self.normalize_name(table_name)
            column_name = self.normalize_name(column_name)
            if column_name.startswith("sys_clustering_column"):
                continue  # ignoring clustering column

            type_instance = self._resolve_column_type(
                coltype,
                column_name,
                character_maximum_length,
                numeric_precision,
                numeric_scale,
            )

            current_table_pks = schema_primary_keys.get(table_name)

            ans.append(
                {
                    "name": column_name,
                    "type": type_instance,
                    "nullable": is_nullable == "YES",
                    "default": column_default,
                    "autoincrement": is_identity == "YES",
                    "comment": comment if comment != "" else None,
                    "primary_key": (
                        (
                            column_name
                            in schema_primary_keys[table_name]["constrained_columns"]
                        )
                        if current_table_pks
                        else False
                    ),
                }
            )

        # If we didn't find any columns for the table, the table doesn't exist.
        if len(ans) == 0:
            raise sa_exc.NoSuchTableError()
        return ans

    def get_columns(self, connection, table_name, schema=None, **kw):
        """
        Gets all column info given the table info
        """
        schema = schema or self.default_schema_name
        if not schema:
            _, schema = self._current_database_schema(connection, **kw)

        schema_columns = self._get_schema_columns(connection, schema, **kw)
        if schema_columns is None:
            # Too many results, fall back to only query about single table
            return self._get_table_columns(connection, table_name, schema, **kw)
        normalized_table_name = self.normalize_name(table_name)
        if normalized_table_name not in schema_columns:
            raise sa_exc.NoSuchTableError()
        return schema_columns[normalized_table_name]

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        """
        Gets all table names.
        """
        schema = schema or self.default_schema_name
        if schema:
            cursor = connection.execute(
                text(
                    f"SHOW /* sqlalchemy:get_table_names */ TABLES IN {self._denormalize_quote_join(schema)}"
                )
            )
        else:
            cursor = connection.execute(
                text("SHOW /* sqlalchemy:get_table_names */ TABLES")
            )

        # The table name is read from the second column of the SHOW output.
        return [self.normalize_name(row[1]) for row in cursor]

    @reflection.cache
    def get_view_names(self, connection, schema=None, **kw):
        """
        Gets all view names
        """
        schema = schema or self.default_schema_name
        if schema:
            cursor = connection.execute(
                text(
                    f"SHOW /* sqlalchemy:get_view_names */ VIEWS IN {self._denormalize_quote_join(schema)}"
                )
            )
        else:
            cursor = connection.execute(
                text("SHOW /* sqlalchemy:get_view_names */ VIEWS")
            )

        return [self.normalize_name(row[1]) for row in cursor]

    @reflection.cache
    def get_view_definition(self, connection, view_name, schema=None, **kw):
        """
        Gets the view definition

        Returns the ``text`` column of the first matching ``SHOW VIEWS`` row,
        or ``None`` when there is no row or the row cannot be read.
        """
        schema = schema or self.default_schema_name
        if schema:
            cursor = connection.execute(
                text(
                    "SHOW /* sqlalchemy:get_view_definition */ VIEWS "
                    f"LIKE '{self._denormalize_quote_join(view_name)}' IN {self._denormalize_quote_join(schema)}"
                )
            )
        else:
            cursor = connection.execute(
                text(
                    "SHOW /* sqlalchemy:get_view_definition */ VIEWS "
                    f"LIKE '{self._denormalize_quote_join(view_name)}'"
                )
            )

        n2i = self.__class__._map_name_to_idx(cursor)
        try:
            ret = cursor.fetchone()
            if ret:
                return ret[n2i["text"]]
        except Exception:
            # Best effort: any failure while reading the row means "no definition".
            pass
        return None

    def get_temp_table_names(self, connection, schema=None, **kw):
        """Return names of tables whose ``kind`` column is ``TEMPORARY``."""
        schema = schema or self.default_schema_name
        if schema:
            cursor = connection.execute(
                text(
                    "SHOW /* sqlalchemy:get_temp_table_names */ TABLES "
                    f"IN {self._denormalize_quote_join(schema)}"
                )
            )
        else:
            cursor = connection.execute(
                text("SHOW /* sqlalchemy:get_temp_table_names */ TABLES")
            )

        ret = []
        n2i = self.__class__._map_name_to_idx(cursor)
        for row in cursor:
            if row[n2i["kind"]] == "TEMPORARY":
                ret.append(self.normalize_name(row[n2i["name"]]))

        return ret

    def get_schema_names(self, connection, **kw):
        """
        Gets all schema names.
        """
        cursor = connection.execute(
            text("SHOW /* sqlalchemy:get_schema_names */ SCHEMAS")
        )

        return [self.normalize_name(row[1]) for row in cursor]

    @reflection.cache
    def get_sequence_names(self, connection, schema=None, **kw):
        """Return the names of all sequences, optionally limited to ``schema``.

        :return: list of normalized sequence names; an empty list when the
            database reports error 2003 (schema does not exist).
        :raises sqlalchemy.exc.ProgrammingError: for any other database error.
        """
        sql_command = "SHOW SEQUENCES {}".format(
            f"IN SCHEMA {self.normalize_name(schema)}" if schema else ""
        )
        try:
            cursor = connection.execute(text(sql_command))
            return [self.normalize_name(row[0]) for row in cursor]
        except sa_exc.ProgrammingError as pe:
            if pe.orig.errno == 2003:
                # Schema does not exist
                return []
            # Any other failure must surface instead of silently yielding None.
            raise

    def _get_table_comment(self, connection, table_name, schema=None, **kw):
        """
        Returns the first ``SHOW TABLES LIKE`` row for the table (or ``None``);
        the caller reads the ``comment`` column from it.
        """
        sql_command = (
            "SHOW /* sqlalchemy:_get_table_comment */ "
            "TABLES LIKE '{}'{}".format(
                table_name,
                f" IN SCHEMA {self.normalize_name(schema)}" if schema else "",
            )
        )
        cursor = connection.execute(text(sql_command))
        return cursor.fetchone()

    def _get_view_comment(self, connection, table_name, schema=None, **kw):
        """
        Returns the first ``SHOW VIEWS LIKE`` row for the view (or ``None``);
        the caller reads the ``comment`` column from it.
        """
        sql_command = (
            "SHOW /* sqlalchemy:_get_view_comment */ "
            "VIEWS LIKE '{}'{}".format(
                table_name,
                f" IN SCHEMA {self.normalize_name(schema)}" if schema else "",
            )
        )
        cursor = connection.execute(text(sql_command))
        return cursor.fetchone()

    def get_table_comment(self, connection, table_name, schema=None, **kw):
        """
        Returns comment associated with a table (or view) in a dictionary as
        SQLAlchemy expects. Note that since SQLAlchemy may not (in fact,
        typically does not) know if this is a table or a view, we have to
        handle both cases here.
        """
        result = self._get_table_comment(connection, table_name, schema)
        if result is None:
            # the "table" being reflected is actually a view
            result = self._get_view_comment(connection, table_name, schema)

        return {
            "text": (
                result._mapping["comment"]
                if result and result._mapping["comment"]
                else None
            )
        }

    def connect(self, *cargs, **cparams):
        """Open a DB-API connection.

        When ``_ENABLE_SQLALCHEMY_AS_APPLICATION_NAME`` is true the keyword
        parameters are first passed through
        ``_update_connection_application_name``.
        """
        if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME:
            cparams = _update_connection_application_name(**cparams)
        return super().connect(*cargs, **cparams)
903
+
904
+
905
@sa_vnt.listens_for(Table, "before_create")
def check_table(table, connection, _ddl_runner, **kw):
    """Reject creation of tables that declare indexes on a Snowflake dialect.

    Registered as a ``before_create`` listener on ``Table``; tables created
    through any other dialect are left alone.

    :raises NotImplementedError: if the DDL runner's dialect is a
        ``SnowflakeDialect`` and the table has at least one index.
    """
    if not isinstance(_ddl_runner.dialect, SnowflakeDialect):
        return
    if table.indexes:
        raise NotImplementedError("Snowflake does not support indexes")
909
+
910
+
911
# Module-level alias for the dialect class.
# NOTE(review): presumably the attribute SQLAlchemy's dialect loader looks up - confirm.
dialect = SnowflakeDialect