snowflake-sqlalchemy 1.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. snowflake/sqlalchemy/__init__.py +162 -0
  2. snowflake/sqlalchemy/_constants.py +14 -0
  3. snowflake/sqlalchemy/base.py +1188 -0
  4. snowflake/sqlalchemy/compat.py +36 -0
  5. snowflake/sqlalchemy/custom_commands.py +627 -0
  6. snowflake/sqlalchemy/custom_types.py +155 -0
  7. snowflake/sqlalchemy/exc.py +82 -0
  8. snowflake/sqlalchemy/functions.py +16 -0
  9. snowflake/sqlalchemy/parser/custom_type_parser.py +245 -0
  10. snowflake/sqlalchemy/provision.py +12 -0
  11. snowflake/sqlalchemy/requirements.py +313 -0
  12. snowflake/sqlalchemy/snowdialect.py +1029 -0
  13. snowflake/sqlalchemy/sql/__init__.py +3 -0
  14. snowflake/sqlalchemy/sql/custom_schema/__init__.py +9 -0
  15. snowflake/sqlalchemy/sql/custom_schema/clustered_table.py +37 -0
  16. snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py +127 -0
  17. snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py +13 -0
  18. snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py +117 -0
  19. snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py +63 -0
  20. snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py +102 -0
  21. snowflake/sqlalchemy/sql/custom_schema/options/__init__.py +33 -0
  22. snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py +63 -0
  23. snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py +58 -0
  24. snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py +63 -0
  25. snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py +25 -0
  26. snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py +65 -0
  27. snowflake/sqlalchemy/sql/custom_schema/options/keywords.py +14 -0
  28. snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py +67 -0
  29. snowflake/sqlalchemy/sql/custom_schema/options/table_option.py +84 -0
  30. snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py +94 -0
  31. snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py +70 -0
  32. snowflake/sqlalchemy/sql/custom_schema/table_from_query.py +54 -0
  33. snowflake/sqlalchemy/util.py +344 -0
  34. snowflake/sqlalchemy/version.py +6 -0
  35. snowflake_sqlalchemy-1.7.3.dist-info/METADATA +737 -0
  36. snowflake_sqlalchemy-1.7.3.dist-info/RECORD +39 -0
  37. snowflake_sqlalchemy-1.7.3.dist-info/WHEEL +4 -0
  38. snowflake_sqlalchemy-1.7.3.dist-info/entry_points.txt +2 -0
  39. snowflake_sqlalchemy-1.7.3.dist-info/licenses/LICENSE.txt +202 -0
@@ -0,0 +1,1029 @@
1
+ #
2
+ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+ import operator
5
+ import re
6
+ from collections import defaultdict
7
+ from enum import Enum
8
+ from functools import reduce
9
+ from typing import Any, Collection, Optional
10
+ from urllib.parse import unquote_plus
11
+
12
+ import sqlalchemy.sql.sqltypes as sqltypes
13
+ from sqlalchemy import event as sa_vnt
14
+ from sqlalchemy import exc as sa_exc
15
+ from sqlalchemy import util as sa_util
16
+ from sqlalchemy.engine import URL, default, reflection
17
+ from sqlalchemy.schema import Table
18
+ from sqlalchemy.sql import text
19
+ from sqlalchemy.sql.elements import quoted_name
20
+ from sqlalchemy.sql.sqltypes import NullType
21
+ from sqlalchemy.types import FLOAT, Date, DateTime, Float, Time
22
+
23
+ from snowflake.connector import errors as sf_errors
24
+ from snowflake.connector.connection import DEFAULT_CONFIGURATION
25
+ from snowflake.connector.constants import UTF8
26
+ from snowflake.sqlalchemy.compat import returns_unicode
27
+
28
+ from ._constants import DIALECT_NAME
29
+ from .base import (
30
+ SnowflakeCompiler,
31
+ SnowflakeDDLCompiler,
32
+ SnowflakeExecutionContext,
33
+ SnowflakeIdentifierPreparer,
34
+ SnowflakeTypeCompiler,
35
+ )
36
+ from .custom_types import (
37
+ StructuredType,
38
+ _CUSTOM_Date,
39
+ _CUSTOM_DateTime,
40
+ _CUSTOM_Float,
41
+ _CUSTOM_Time,
42
+ )
43
+ from .parser.custom_type_parser import * # noqa
44
+ from .parser.custom_type_parser import _CUSTOM_DECIMAL # noqa
45
+ from .parser.custom_type_parser import ischema_names, parse_index_columns, parse_type
46
+ from .sql.custom_schema.custom_table_prefix import CustomTablePrefix
47
+ from .util import (
48
+ _update_connection_application_name,
49
+ parse_url_boolean,
50
+ parse_url_integer,
51
+ )
52
+
53
# Maps generic SQLAlchemy column types to the Snowflake-specific
# implementations imported from .custom_types. SnowflakeDialect exposes
# this mapping through its ``colspecs`` class attribute.
colspecs = {
    Date: _CUSTOM_Date,
    DateTime: _CUSTOM_DateTime,
    Time: _CUSTOM_Time,
    Float: _CUSTOM_Float,
}

# NOTE(review): presumably toggles reporting SQLAlchemy as the connection's
# application name; the code that reads this flag is not in this part of the
# file -- confirm before relying on it.
_ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True
61
+
62
+
63
class SnowflakeIsolationLevel(Enum):
    """Isolation level names accepted by SnowflakeDialect.

    Each member's value is the string form that the dialect compares
    against in set_isolation_level() and reports from
    get_isolation_level_values().
    """

    # Regular transactional mode (autocommit switched off).
    READ_COMMITTED = "READ COMMITTED"
    # Every statement is committed by the driver as it runs.
    AUTOCOMMIT = "AUTOCOMMIT"
66
+
67
+
68
+ class SnowflakeDialect(default.DefaultDialect):
69
# Identity of the dialect: the dialect name comes from ._constants and the
# driver name is fixed to "snowflake".
name = DIALECT_NAME
driver = "snowflake"
max_identifier_length = 255
cte_follows_insert = True

# TODO: support SQL caching, for more info see: https://docs.sqlalchemy.org/en/14/core/connections.html#caching-for-third-party-dialects
supports_statement_cache = False

encoding = UTF8
default_paramstyle = "pyformat"
# Type adaptation map and reflected-type-name lookup defined at module
# level / imported from the custom type parser.
colspecs = colspecs
ischema_names = ischema_names

# Whether the target database treats the / division operator as "floor
# division". This is only the class-level default: __init__() and
# initialize() overwrite it per instance with ``force_div_is_floordiv``.
div_is_floordiv = False

# All str types must be converted to Unicode.
convert_unicode = True

# Indicate whether the DB-API can receive SQL statements as Python
# unicode strings
supports_unicode_statements = True
supports_unicode_binds = True
returns_unicode_strings = returns_unicode
description_encoding = None

# No lastrowid support. See SNOW-11155
postfetch_lastrowid = False

# Indicate whether the dialect properly implements rowcount for
# ``UPDATE`` and ``DELETE`` statements.
supports_sane_rowcount = True

# Indicate whether the dialect properly implements rowcount for
# ``UPDATE`` and ``DELETE`` statements when executed via
# executemany.
supports_sane_multi_rowcount = True

# NUMERIC type returns decimal.Decimal
supports_native_decimal = True

# The dialect supports a native boolean construct.
# This will prevent types.Boolean from generating a CHECK
# constraint when that type is used.
supports_native_boolean = True

# The dialect supports ``ALTER TABLE``.
supports_alter = True

# The dialect supports CREATE SEQUENCE or similar.
supports_sequences = True

# The dialect does not support a native ENUM construct.
supports_native_enum = False

# The dialect supports inserting multiple rows at once.
supports_multivalues_insert = True

# The dialect supports comments
supports_comments = True

# Compiler / preparer / execution-context classes imported from .base.
preparer = SnowflakeIdentifierPreparer
ddl_compiler = SnowflakeDDLCompiler
type_compiler = SnowflakeTypeCompiler
statement_compiler = SnowflakeCompiler
execution_ctx_cls = SnowflakeExecutionContext

# indicates symbol names are UPPERCASEd if they are case insensitive
# within the database. If this is True, the methods normalize_name()
# and denormalize_name() must be provided.
requires_name_normalize = True

# NOTE(review): looks like a second spelling of
# ``supports_multivalues_insert`` above -- confirm which one is consumed.
multivalues_inserts = True

supports_schemas = True

sequences_optional = True

supports_is_distinct_from = True

supports_identity_columns = True
150
+
151
def __init__(
    self,
    force_div_is_floordiv: bool = True,
    isolation_level: Optional[str] = SnowflakeIsolationLevel.READ_COMMITTED.value,
    **kwargs: Any,
):
    """
    Create the dialect.

    :param force_div_is_floordiv: value stored on the instance and copied
        into ``div_is_floordiv``.
    :param isolation_level: isolation level name forwarded to the parent
        dialect constructor; defaults to READ COMMITTED.
    :param kwargs: any further keyword arguments, forwarded unchanged to
        the parent dialect constructor.
    """
    super().__init__(isolation_level=isolation_level, **kwargs)
    # Keep the requested value separately so initialize() can re-apply it.
    self.force_div_is_floordiv = self.div_is_floordiv = force_div_is_floordiv
160
+
161
def initialize(self, connection):
    """
    Run the parent dialect initialization, then re-apply the division flag.

    ``div_is_floordiv`` is set again from ``force_div_is_floordiv`` after
    the parent call so the value chosen in __init__ is what ends up on the
    instance.
    """
    super().initialize(connection)
    forced_value = self.force_div_is_floordiv
    self.div_is_floordiv = forced_value
164
+
165
@classmethod
def dbapi(cls):
    """
    Return the DB-API module; simply delegates to import_dbapi().
    """
    dbapi_module = cls.import_dbapi()
    return dbapi_module
168
+
169
@classmethod
def import_dbapi(cls):
    """
    Import and return the ``snowflake.connector`` module.

    The import is done inside the method, so the connector is only loaded
    when this method is called.
    """
    from snowflake import connector as snowflake_connector

    return snowflake_connector
174
+
175
@staticmethod
def parse_query_param_type(name: str, value: Any) -> Any:
    """
    Cast a URL query parameter to the type the connector declares for it.

    The expected type is looked up in the connector's DEFAULT_CONFIGURATION
    by parameter name. The value is returned unchanged when the name is
    unknown (or its entry is falsy), when the value already has an accepted
    type, or when the accepted types include neither bool nor int.
    """
    type_configuration = DEFAULT_CONFIGURATION.get(name)
    if not type_configuration:
        return value

    # Each configuration entry is a pair; only its second item (the type or
    # tuple of types) matters here.
    _unused, declared = type_configuration
    accepted_types = declared if isinstance(declared, tuple) else (declared,)

    if isinstance(value, accepted_types):
        return value
    # bool is checked before int, so a parameter accepting both is parsed
    # as a boolean.
    if bool in accepted_types:
        return parse_url_boolean(value)
    if int in accepted_types:
        return parse_url_integer(value)
    return value
194
+
195
def create_connect_args(self, url: URL):
    """
    Translate a SQLAlchemy URL into connector ``connect()`` arguments.

    * A database component of the form ``db/schema`` is split into separate
      ``database`` and ``schema`` options (each part URL-unquoted); more
      than two parts raise ``ArgumentError``.
    * When a host is given without a port and without the
      ``.snowflakecomputing.com`` suffix, it is treated as an account name:
      ``account`` is derived from it, the suffix is appended to ``host``
      and ``port`` is set to ``"443"``.
    * ``autocommit`` is always set to False.
    * ``cache_column_metadata`` is removed from the query string and stored
      on the dialect; all other query parameters are type-cast and added.

    Returns a ``([], opts)`` tuple.
    """
    opts = url.translate_connect_args(username="user")

    if "database" in opts:
        name_spaces = [unquote_plus(part) for part in opts["database"].split("/")]
        if len(name_spaces) > 2:
            raise sa_exc.ArgumentError(
                f"Invalid name space is specified: {opts['database']}"
            )
        if len(name_spaces) == 2:
            opts["database"], opts["schema"] = name_spaces
        # A single component is left exactly as the URL delivered it.

    host_is_account_name = (
        "host" in opts
        and ".snowflakecomputing.com" not in opts["host"]
        and not opts.get("port")
    )
    if host_is_account_name:
        account = opts["host"]
        if "." in account:
            # remove region subdomain, then remove external ID
            account = account[: account.find(".")].split("-")[0]
        opts["account"] = account
        opts["host"] = opts["host"] + ".snowflakecomputing.com"
        opts["port"] = "443"
    opts["autocommit"] = False  # autocommit is disabled by default

    query = dict(**url.query)  # make mutable
    cache_column_metadata = query.pop("cache_column_metadata", None)
    if cache_column_metadata:
        self._cache_column_metadata = parse_url_boolean(cache_column_metadata)
    else:
        self._cache_column_metadata = False

    # URL sets the query parameter values as strings, we need to cast to expected types when necessary
    for param_name, raw_value in query.items():
        opts[param_name] = self.parse_query_param_type(param_name, raw_value)

    return ([], opts)
234
+
235
@reflection.cache
def has_table(self, connection, table_name, schema=None, **kw):
    """
    Report whether ``table_name`` exists, by delegating to _has_object()
    with the object type "TABLE".
    """
    table_exists = self._has_object(connection, "TABLE", table_name, schema)
    return table_exists
241
+
242
def get_isolation_level_values(self, dbapi_connection):
    """
    Return the isolation level names this dialect accepts, as strings:
    READ COMMITTED followed by AUTOCOMMIT. ``dbapi_connection`` is unused.
    """
    accepted_levels = (
        SnowflakeIsolationLevel.READ_COMMITTED,
        SnowflakeIsolationLevel.AUTOCOMMIT,
    )
    return [level.value for level in accepted_levels]
247
+
248
def do_rollback(self, dbapi_connection):
    """
    Roll back by calling ``rollback()`` on the given DB-API connection.
    """
    rollback = dbapi_connection.rollback
    rollback()
250
+
251
def do_commit(self, dbapi_connection):
    """
    Commit by calling ``commit()`` on the given DB-API connection.
    """
    commit = dbapi_connection.commit
    commit()
253
+
254
def get_default_isolation_level(self, dbapi_conn):
    """
    Return the default isolation level name, which is always the
    READ COMMITTED string; ``dbapi_conn`` is not consulted.
    """
    default_level = SnowflakeIsolationLevel.READ_COMMITTED
    return default_level.value
256
+
257
def set_isolation_level(self, dbapi_connection, level):
    """
    Apply an isolation level by toggling the connection's autocommit mode.

    Autocommit is switched on only when ``level`` equals the AUTOCOMMIT
    value; any other level switches it off.
    """
    wants_autocommit = bool(level == SnowflakeIsolationLevel.AUTOCOMMIT.value)
    dbapi_connection.autocommit(wants_autocommit)
262
+
263
@reflection.cache
def has_sequence(self, connection, sequence_name, schema=None, **kw):
    """
    Report whether ``sequence_name`` exists, by delegating to _has_object()
    with the object type "SEQUENCE".
    """
    sequence_exists = self._has_object(connection, "SEQUENCE", sequence_name, schema)
    return sequence_exists
269
+
270
def _has_object(self, connection, object_type, object_name, schema=None):
    """
    Check for an object's existence by issuing ``DESC <object_type> <name>``.

    Returns True when the statement yields at least one row, and False when
    it yields none or when the driver error wrapped by SQLAlchemy is exactly
    a connector ``ProgrammingError``. Any other database error propagates.
    """
    qualified_name = self._denormalize_quote_join(schema, object_name)
    statement = text(
        f"DESC {object_type} /* sqlalchemy:_has_object */ {qualified_name}"
    )
    try:
        first_row = connection.execute(statement).fetchone()
    except sa_exc.DBAPIError as error:
        # Exact class comparison: subclasses of ProgrammingError are re-raised.
        if error.orig.__class__ == sf_errors.ProgrammingError:
            return False
        raise
    return first_row is not None
283
+
284
def normalize_name(self, name):
    """
    Convert a name as reported by the database into SQLAlchemy's form.

    * ``None`` and the empty string are returned as-is.
    * An all-uppercase name whose lowercase form needs no quoting is
      returned lowercased.
    * An all-lowercase name is returned as a ``quoted_name`` with quoting
      forced.
    * Anything else is returned unchanged.
    """
    if name is None:
        return None
    if name == "":
        return ""

    lowered = name.lower()
    if name.upper() == name and not self.identifier_preparer._requires_quotes(
        lowered
    ):
        return lowered
    if lowered == name:
        return quoted_name(name, quote=True)
    return name
297
+
298
def denormalize_name(self, name):
    """
    Convert a SQLAlchemy-side name into the form sent to the database.

    ``None`` and the empty string are returned as-is. An all-lowercase name
    that needs no quoting is uppercased; every other name is returned
    unchanged.
    """
    if name is None:
        return None
    if name == "":
        return ""

    lowered = name.lower()
    if lowered == name and not self.identifier_preparer._requires_quotes(lowered):
        return name.upper()
    return name
308
+
309
def _denormalize_quote_join(self, *idents):
    """
    Build a dotted identifier from the given parts.

    ``None`` parts are skipped; every remaining part is split with the
    identifier preparer's ``_split_schema_by_dot``, the pieces are
    concatenated in order, passed through ``_quote_free_identifiers`` and
    joined with ".".
    """
    preparer = self.identifier_preparer
    split_parts = [
        preparer._split_schema_by_dot(ident) for ident in idents if ident is not None
    ]
    flattened = reduce(operator.add, split_parts)
    return ".".join(preparer._quote_free_identifiers(*flattened))
316
+
317
@reflection.cache
def _current_database_schema(self, connection, **kw):
    """
    Query the session's current database and schema.

    Returns a ``(database, schema)`` tuple with both names passed through
    normalize_name().
    """
    row = connection.execute(
        text("select current_database(), current_schema();")
    ).fetchone()
    database_name, schema_name = row[0], row[1]
    return (
        self.normalize_name(database_name),
        self.normalize_name(schema_name),
    )
326
+
327
def _get_default_schema_name(self, connection):
    """
    Return the session's current schema as the default schema name.
    """
    # NOTE: no cache object is passed here
    _database, schema_name = self._current_database_schema(connection)
    return schema_name
331
+
332
@staticmethod
def _map_name_to_idx(result):
    """
    Map each result column name to its position.

    Reads ``result.cursor.description`` and uses the first item of every
    entry as the column name. If a name occurs more than once, the last
    position wins.
    """
    return {
        column[0]: position
        for position, column in enumerate(result.cursor.description)
    }
338
+
339
@reflection.cache
def get_check_constraints(self, connection, table_name, schema=None, **kw):
    """
    Return the check constraints of a table, which is always an empty list.

    ``schema`` now defaults to None, matching the other reflection methods
    of this dialect (get_pk_constraint, get_foreign_keys, get_columns), so
    the method can be called without a schema.
    """
    # check constraints are not supported by Snowflake
    return []
343
+
344
@reflection.cache
def _get_schema_primary_keys(self, connection, schema, **kw):
    """
    Collect the primary keys of every table in ``schema``.

    ``schema`` is interpolated into the SHOW statement as given. Returns a
    dict keyed by normalized table name; each value holds the constraint
    ``name`` (from the first row seen for that table) and the list of
    ``constrained_columns`` in result order.
    """
    rows = connection.execute(
        text(
            f"SHOW /* sqlalchemy:_get_schema_primary_keys */PRIMARY KEYS IN SCHEMA {schema}"
        )
    )
    primary_keys = {}
    for row in rows:
        fields = row._mapping
        table_name = self.normalize_name(fields["table_name"])
        entry = primary_keys.get(table_name)
        if entry is None:
            entry = {
                "constrained_columns": [],
                "name": self.normalize_name(fields["constraint_name"]),
            }
            primary_keys[table_name] = entry
        entry["constrained_columns"].append(
            self.normalize_name(fields["column_name"])
        )
    return primary_keys
363
+
364
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
    """
    Return the primary key constraint of ``table_name``.

    The schema is taken from the argument, then the dialect's default
    schema, then the session's current schema. The result comes from the
    per-schema primary key lookup; a table without an entry yields
    ``{"constrained_columns": [], "name": None}``.
    """
    requested_schema = schema or self.default_schema_name
    current_database, current_schema = self._current_database_schema(
        connection, **kw
    )
    qualified_schema = self._denormalize_quote_join(
        current_database, requested_schema or current_schema
    )
    schema_primary_keys = self._get_schema_primary_keys(
        connection, self.denormalize_name(qualified_schema), **kw
    )
    no_primary_key = {"constrained_columns": [], "name": None}
    return schema_primary_keys.get(table_name, no_primary_key)
375
+
376
@reflection.cache
def _get_schema_unique_constraints(self, connection, schema, **kw):
    """
    Collect the unique constraints of every table in ``schema``.

    Rows of SHOW UNIQUE KEYS are grouped by normalized constraint name, then
    regrouped by table. Returns a ``defaultdict(list)`` keyed by normalized
    table name whose values are lists of ``{"column_names": [...],
    "name": ...}`` dicts.
    """
    rows = connection.execute(
        text(
            f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN SCHEMA {schema}"
        )
    )
    constraints_by_name = {}
    for row in rows:
        fields = row._mapping
        name = self.normalize_name(fields["constraint_name"])
        column = self.normalize_name(fields["column_name"])
        existing = constraints_by_name.get(name)
        if existing is None:
            constraints_by_name[name] = {
                "column_names": [column],
                "name": name,
                "table_name": self.normalize_name(fields["table_name"]),
            }
        else:
            existing["column_names"].append(column)

    constraints_by_table = defaultdict(list)
    for constraint in constraints_by_name.values():
        # table_name is only a grouping key; it is removed from the result.
        owner = constraint.pop("table_name")
        constraints_by_table[owner].append(constraint)
    return constraints_by_table
402
+
403
def get_unique_constraints(self, connection, table_name, schema=None, **kw):
    """
    Return the unique constraints of ``table_name`` as a list of dicts.

    ``schema`` now defaults to None, matching get_pk_constraint() and
    get_foreign_keys(); the body already falls back to the dialect's default
    schema and then to the session's current schema when it is falsy.
    A table without unique constraints yields an empty list.
    """
    schema = schema or self.default_schema_name
    current_database, current_schema = self._current_database_schema(
        connection, **kw
    )
    full_schema_name = self._denormalize_quote_join(
        current_database, schema if schema else current_schema
    )
    return self._get_schema_unique_constraints(
        connection, self.denormalize_name(full_schema_name), **kw
    ).get(table_name, [])
414
+
415
@reflection.cache
def _get_schema_foreign_keys(self, connection, schema, **kw):
    """
    Collect the foreign keys of every table in ``schema``.

    Rows of SHOW IMPORTED KEYS are grouped by normalized foreign key name
    (one row per column pair), then regrouped by the owning table. Returns a
    dict keyed by normalized table name whose values are lists of foreign
    key description dicts (``constrained_columns``, ``referred_schema``,
    ``referred_table``, ``referred_columns``, ``name``, ``options``).
    """
    _, current_schema = self._current_database_schema(connection, **kw)
    rows = connection.execute(
        text(
            f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN SCHEMA {schema}"
        )
    )
    normalize = self.normalize_name
    foreign_keys = {}
    for row in rows:
        fields = row._mapping
        name = normalize(fields["fk_name"])

        if name in foreign_keys:
            # Further row of a multi-column key: only extend the column lists.
            known = foreign_keys[name]
            known["constrained_columns"].append(normalize(fields["fk_column_name"]))
            known["referred_columns"].append(normalize(fields["pk_column_name"]))
            continue

        referred_schema = normalize(fields["pk_schema_name"])
        constrained_column = normalize(fields["fk_column_name"])
        # referred schema should be None in context where it doesn't need to be specified
        # https://docs.sqlalchemy.org/en/14/core/reflection.html#reflection-schema-qualified-interaction
        if referred_schema in (self.default_schema_name, current_schema):
            referred_schema = None

        entry = {
            "constrained_columns": [constrained_column],
            "referred_schema": referred_schema,
            "referred_table": normalize(fields["pk_table_name"]),
            "referred_columns": [normalize(fields["pk_column_name"])],
            "name": name,
            "table_name": normalize(fields["fk_table_name"]),
        }

        # Only rules other than "NO ACTION" are reported as options.
        options = {}
        delete_rule = normalize(fields["delete_rule"])
        if delete_rule != "NO ACTION":
            options["ondelete"] = delete_rule
        update_rule = normalize(fields["update_rule"])
        if update_rule != "NO ACTION":
            options["onupdate"] = update_rule
        entry["options"] = options

        foreign_keys[name] = entry

    keys_by_table = {}
    for entry in foreign_keys.values():
        # table_name is only a grouping key; it is left out of the result.
        description = {
            key: value for key, value in entry.items() if key != "table_name"
        }
        keys_by_table.setdefault(entry["table_name"], []).append(description)
    return keys_by_table
476
+
477
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
    """
    Return all foreign keys of ``table_name`` as a list of dicts.

    The schema is taken from the argument, then the dialect's default
    schema, then the session's current schema. A table without foreign keys
    yields an empty list.
    """
    requested_schema = schema or self.default_schema_name
    current_database, current_schema = self._current_database_schema(
        connection, **kw
    )
    qualified_schema = self._denormalize_quote_join(
        current_database, requested_schema or current_schema
    )

    keys_by_table = self._get_schema_foreign_keys(
        connection, self.denormalize_name(qualified_schema), **kw
    )
    return keys_by_table.get(table_name, [])
493
+
494
def table_columns_as_dict(self, columns):
    """
    Index a sequence of column description dicts by their ``"name"`` entry.

    If two columns share a name, the later one replaces the earlier one.
    """
    return {column["name"]: column for column in columns}
499
+
500
@reflection.cache
def _get_schema_columns(self, connection, schema, **kw):
    """Get all columns in the schema.

    Returns a dict mapping each normalized table name to a list of column
    description dicts (name, type, nullable, default, autoincrement,
    comment, primary_key and, for identity columns, identity).

    If we hit the 'Information schema query returned too much data' problem
    (errno 90030), return None instead, as it is cacheable and is an
    unexpected return type for this function. Any other ProgrammingError
    is re-raised.
    """
    ans = {}
    current_database, _ = self._current_database_schema(connection, **kw)
    full_schema_name = self._denormalize_quote_join(current_database, schema)
    # Per-table results of _get_table_columns(), keyed by
    # (schema_name, table_name), so each table is described at most once.
    full_columns_descriptions = {}
    try:
        schema_primary_keys = self._get_schema_primary_keys(
            connection, full_schema_name, **kw
        )
        schema_name = self.denormalize_name(schema)

        # Tables whose SHOW TABLES row flags them with the ICEBERG prefix;
        # structured-type columns of these tables are resolved per table below.
        iceberg_table_names = self.get_table_names_with_prefix(
            connection,
            schema=schema_name,
            prefix=CustomTablePrefix.ICEBERG.name,
            info_cache=kw.get("info_cache", None),
        )

        result = connection.execute(
            text(
                """
        SELECT /* sqlalchemy:_get_schema_columns */
               ic.table_name,
               ic.column_name,
               ic.data_type,
               ic.character_maximum_length,
               ic.numeric_precision,
               ic.numeric_scale,
               ic.is_nullable,
               ic.column_default,
               ic.is_identity,
               ic.comment,
               ic.identity_start,
               ic.identity_increment
          FROM information_schema.columns ic
         WHERE ic.table_schema=:table_schema
         ORDER BY ic.ordinal_position"""
            ),
            {"table_schema": schema_name},
        )
    except sa_exc.ProgrammingError as pe:
        if pe.orig.errno == 90030:
            # This means that there are too many tables in the schema, we need to go more granular
            return None  # None triggers _get_table_columns while staying cacheable
        raise
    for (
        table_name,
        column_name,
        coltype,
        character_maximum_length,
        numeric_precision,
        numeric_scale,
        is_nullable,
        column_default,
        is_identity,
        comment,
        identity_start,
        identity_increment,
    ) in result:
        table_name = self.normalize_name(table_name)
        column_name = self.normalize_name(column_name)
        # The table entry is created before the clustering-column check, so
        # a table can be present with an empty column list.
        if table_name not in ans:
            ans[table_name] = list()
        if column_name.startswith("sys_clustering_column"):
            continue  # ignoring clustering column
        col_type = self.ischema_names.get(coltype, None)
        col_type_kw = {}
        if col_type is None:
            col_type = NullType
        else:
            # Pick the constructor arguments that fit the resolved type class.
            if issubclass(col_type, FLOAT):
                col_type_kw["precision"] = numeric_precision
                col_type_kw["decimal_return_scale"] = numeric_scale
            elif issubclass(col_type, sqltypes.Numeric):
                col_type_kw["precision"] = numeric_precision
                col_type_kw["scale"] = numeric_scale
            elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)):
                col_type_kw["length"] = character_maximum_length
            elif (
                issubclass(col_type, StructuredType)
                and table_name in iceberg_table_names
            ):
                # Structured columns of Iceberg tables take their whole
                # description from the per-table DESC query instead of the
                # information schema row.
                if (schema_name, table_name) not in full_columns_descriptions:
                    full_columns_descriptions[(schema_name, table_name)] = (
                        self.table_columns_as_dict(
                            self._get_table_columns(
                                connection, table_name, schema_name
                            )
                        )
                    )

                if (
                    (schema_name, table_name) in full_columns_descriptions
                    and column_name
                    in full_columns_descriptions[(schema_name, table_name)]
                ):
                    ans[table_name].append(
                        full_columns_descriptions[(schema_name, table_name)][
                            column_name
                        ]
                    )
                    continue
                else:
                    # Column missing from the per-table description: fall
                    # through and report it with an unknown type.
                    col_type = NullType
        if col_type == NullType:
            sa_util.warn(
                f"Did not recognize type '{coltype}' of column '{column_name}'"
            )

        type_instance = col_type(**col_type_kw)

        current_table_pks = schema_primary_keys.get(table_name)

        ans[table_name].append(
            {
                "name": column_name,
                "type": type_instance,
                "nullable": is_nullable == "YES",
                "default": column_default,
                "autoincrement": is_identity == "YES",
                "comment": comment,
                # False when the table has no primary key entry at all.
                "primary_key": (
                    (
                        column_name
                        in schema_primary_keys[table_name]["constrained_columns"]
                    )
                    if current_table_pks
                    else False
                ),
            }
        )
        if is_identity == "YES":
            ans[table_name][-1]["identity"] = {
                "start": identity_start,
                "increment": identity_increment,
            }
    return ans
640
+
641
@reflection.cache
def _get_table_columns(self, connection, table_name, schema=None, **kw):
    """Get all columns in a table in a schema.

    Runs ``DESC TABLE <schema>.<table> TYPE = COLUMNS`` and returns a list of
    column description dicts (name, type, nullable, default, autoincrement,
    comment, primary_key and, for identity columns, identity). When
    ``schema`` is falsy the session's current schema is used.

    The needed fields are read from each row by position instead of
    unpacking a fixed number of columns, so extra trailing columns in the
    DESC output no longer raise ``ValueError``.

    Raises ``NoSuchTableError`` when no column is collected.
    """
    ans = []
    current_database, default_schema = self._current_database_schema(
        connection, **kw
    )
    schema = schema if schema else default_schema
    table_schema = self.denormalize_name(schema)
    table_name = self.denormalize_name(table_name)
    result = connection.execute(
        text(
            "DESC /* sqlalchemy:_get_schema_columns */"
            f" TABLE {table_schema}.{table_name} TYPE = COLUMNS"
        )
    )
    # Compiled once; the same pattern is applied to every column default.
    identity_pattern = re.compile(
        r"IDENTITY START (?P<start>\d+) INCREMENT (?P<increment>\d+) (?P<order_type>ORDER|NOORDER)"
    )
    for row in result:
        # Row layout (by position): 0 column name, 1 type, 2 kind,
        # 3 nullable flag, 4 default, 5 primary key flag, 6 unique key flag,
        # 7 check, 8 expression, 9 comment; later columns are not used.
        column_name = row[0]
        coltype = row[1]
        is_nullable = row[3]
        column_default = row[4]
        primary_key = row[5]
        comment = row[9]

        column_name = self.normalize_name(column_name)
        if column_name.startswith("sys_clustering_column"):
            continue  # ignoring clustering column
        type_instance = parse_type(coltype)
        if isinstance(type_instance, NullType):
            sa_util.warn(
                f"Did not recognize type '{coltype}' of column '{column_name}'"
            )

        # An identity column is recognised by the shape of its default text.
        identity = None
        match = identity_pattern.match(column_default if column_default else "")
        if match:
            identity = {
                "start": int(match.group("start")),
                "increment": int(match.group("increment")),
                "order_type": match.group("order_type"),
            }
        is_identity = identity is not None

        ans.append(
            {
                "name": column_name,
                "type": type_instance,
                "nullable": is_nullable == "Y",
                "default": None if is_identity else column_default,
                "autoincrement": is_identity,
                "comment": comment if comment != "" else None,
                "primary_key": primary_key == "Y",
            }
        )

        if is_identity:
            ans[-1]["identity"] = identity

    # If we didn't find any columns for the table, the table doesn't exist.
    if len(ans) == 0:
        raise sa_exc.NoSuchTableError()
    return ans
714
+
715
def get_columns(self, connection, table_name, schema=None, **kw):
    """
    Return the column descriptions of ``table_name``.

    The whole schema is reflected first. When that lookup returns None
    (too much data), only the single table is described instead. Raises
    ``NoSuchTableError`` when the schema-wide result does not contain the
    normalized table name.
    """
    schema = schema or self.default_schema_name
    if not schema:
        _, schema = self._current_database_schema(connection, **kw)

    columns_by_table = self._get_schema_columns(connection, schema, **kw)
    if columns_by_table is None:
        # Too many results, fall back to only query about single table
        return self._get_table_columns(connection, table_name, schema, **kw)

    lookup_name = self.normalize_name(table_name)
    if lookup_name in columns_by_table:
        return columns_by_table[lookup_name]
    raise sa_exc.NoSuchTableError()
731
+
732
def get_prefixes_from_data(self, name_to_index_map, row, **kw):
    """
    Return the names of the table prefixes flagged in a SHOW TABLES row.

    For every member of CustomTablePrefix, the column ``is_<member name in
    lowercase>`` is looked up in ``name_to_index_map``; the member's name is
    included when that column exists and the row holds "Y" there.
    """

    def is_flagged(prefix):
        column = f"is_{prefix.name.lower()}"
        return column in name_to_index_map and row[name_to_index_map[column]] == "Y"

    return [prefix.name for prefix in CustomTablePrefix if is_flagged(prefix)]
739
+
740
@reflection.cache
def _get_schema_tables_info(self, connection, schema=None, **kw):
    """
    Retrieve information about all tables in the given schema.

    Falls back to the dialect's default schema when ``schema`` is falsy.
    Returns a dict mapping each normalized table name to
    ``{"prefixes": [...]}``, the prefix names flagged for that table.
    """
    schema = schema or self.default_schema_name
    result = connection.execute(
        text(
            f"SHOW /* sqlalchemy:get_schema_tables_info */ TABLES IN SCHEMA {self._denormalize_quote_join(schema)}"
        )
    )

    column_positions = self._map_name_to_idx(result)
    # Rows are read from the underlying DB-API cursor, so they are indexed
    # by position through column_positions.
    return {
        self.normalize_name(str(row[column_positions["name"]])): {
            "prefixes": self.get_prefixes_from_data(column_positions, row)
        }
        for row in result.cursor.fetchall()
    }
761
+
762
def get_table_names(self, connection, schema=None, **kw):
    """
    Return the names of all tables in the schema as a list.

    Only ``info_cache`` is forwarded from ``kw`` to the schema-wide table
    lookup.
    """
    tables_info = self._get_schema_tables_info(
        connection, schema, info_cache=kw.get("info_cache", None)
    )
    table_names = tables_info.keys()
    return list(table_names)
770
+
771
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
    """
    Return the normalized names of all views.

    The listing is restricted to ``schema`` (or the dialect's default
    schema) when one is available; otherwise a plain SHOW VIEWS is issued.
    The name is read from the second column of each row.
    """
    schema = schema or self.default_schema_name
    statement = "SHOW /* sqlalchemy:get_view_names */ VIEWS"
    if schema:
        statement = f"{statement} IN {self._denormalize_quote_join(schema)}"
    cursor = connection.execute(text(statement))

    view_names = []
    for row in cursor:
        view_names.append(self.normalize_name(row[1]))
    return view_names
789
+
790
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
    """
    Return the definition text of ``view_name``, or None.

    Issues ``SHOW VIEWS LIKE '<name>'`` (restricted to ``schema`` or the
    dialect's default schema when one is available) and returns the "text"
    column of the first row. Returns None when there is no row or when
    reading the row fails.

    The statement is assembled once from adjacent string literals. The
    previous backslash line continuations inside the f-strings embedded the
    source indentation in the SQL text, and the statement was duplicated in
    both branches.
    """
    schema = schema or self.default_schema_name
    statement = (
        "SHOW /* sqlalchemy:get_view_definition */ VIEWS"
        f" LIKE '{self._denormalize_quote_join(view_name)}'"
    )
    if schema:
        statement += f" IN {self._denormalize_quote_join(schema)}"
    cursor = connection.execute(text(statement))

    n2i = self.__class__._map_name_to_idx(cursor)
    try:
        ret = cursor.fetchone()
        if ret:
            return ret[n2i["text"]]
    except Exception:
        # Deliberate best effort kept from the original: any failure while
        # reading the row is reported as "no definition available".
        pass
    return None
819
+
820
def get_temp_table_names(self, connection, schema=None, **kw):
    """
    Return the normalized names of the temporary tables in the schema.

    Lists the tables of ``schema`` (or the dialect's default schema) and
    keeps those whose "kind" column equals "TEMPORARY".

    The statement is built from adjacent string literals. The previous
    backslash line continuation inside the f-string embedded the source
    indentation in the SQL text.
    """
    schema = schema or self.default_schema_name
    cursor = connection.execute(
        text(
            "SHOW /* sqlalchemy:get_temp_table_names */ TABLES"
            f" IN SCHEMA {self._denormalize_quote_join(schema)}"
        )
    )

    n2i = self.__class__._map_name_to_idx(cursor)
    return [
        self.normalize_name(row[n2i["name"]])
        for row in cursor
        if row[n2i["kind"]] == "TEMPORARY"
    ]
836
+
837
def get_schema_names(self, connection, **kw):
    """
    Return the normalized names of all schemas listed by SHOW SCHEMAS.

    The name is read from the second column of each row.
    """
    rows = connection.execute(
        text("SHOW /* sqlalchemy:get_schema_names */ SCHEMAS")
    )

    schema_names = []
    for row in rows:
        schema_names.append(self.normalize_name(row[1]))
    return schema_names
846
+
847
@reflection.cache
def get_sequence_names(self, connection, schema=None, **kw):
    """
    Return the normalized names of the sequences listed by SHOW SEQUENCES.

    The listing is restricted to ``schema`` when one is given. A
    ``ProgrammingError`` with errno 2003 (schema does not exist) yields an
    empty list; any other ``ProgrammingError`` is re-raised. Previously such
    errors were swallowed and the method silently returned None.
    """
    sql_command = "SHOW SEQUENCES {}".format(
        f"IN SCHEMA {self.normalize_name(schema)}" if schema else ""
    )
    try:
        cursor = connection.execute(text(sql_command))
        return [self.normalize_name(row[0]) for row in cursor]
    except sa_exc.ProgrammingError as pe:
        if pe.orig.errno == 2003:
            # Schema does not exist
            return []
        raise
859
+
860
def _get_table_comment(self, connection, table_name, schema=None, **kw):
    """
    Return the first row of ``SHOW TABLES LIKE '<table_name>'``, or None
    when the statement yields no row.

    The listing is restricted to ``schema`` when one is given. The caller
    reads the comment from the returned row.
    """
    schema_clause = f" IN SCHEMA {self.normalize_name(schema)}" if schema else ""
    sql_command = (
        f"SHOW /* sqlalchemy:_get_table_comment */ TABLES LIKE '{table_name}'"
        f"{schema_clause}"
    )
    rows = connection.execute(text(sql_command))
    return rows.fetchone()
873
+
874
def _get_view_comment(self, connection, table_name, schema=None, **kw):
    """
    Return the first row of ``SHOW VIEWS LIKE '<table_name>'``, or None
    when the statement yields no row.

    The listing is restricted to ``schema`` when one is given. The caller
    reads the comment from the returned row.
    """
    schema_clause = f" IN SCHEMA {self.normalize_name(schema)}" if schema else ""
    sql_command = (
        f"SHOW /* sqlalchemy:_get_view_comment */ VIEWS LIKE '{table_name}'"
        f"{schema_clause}"
    )
    rows = connection.execute(text(sql_command))
    return rows.fetchone()
887
+
888
def get_table_comment(self, connection, table_name, schema=None, **kw):
    """
    Return ``{"text": <comment or None>}`` for a table or a view.

    The table lookup is tried first; if it yields no row, the name is
    looked up as a view instead, because the caller may not know which of
    the two is being reflected. A missing row or an empty comment both
    produce ``None`` as the text.
    """
    row = self._get_table_comment(connection, table_name, schema)
    if row is None:
        # the "table" being reflected is actually a view
        row = self._get_view_comment(connection, table_name, schema)

    comment = None
    if row and row._mapping["comment"]:
        comment = row._mapping["comment"]
    return {"text": comment}
907
+
908
def get_table_names_with_prefix(
    self,
    connection,
    *,
    schema,
    prefix,
    **kw,
):
    """
    Return the names of the tables in ``schema`` that carry ``prefix``.

    Table metadata comes from ``self._get_schema_tables_info``; a table is
    selected when ``prefix`` appears in its ``"prefixes"`` entry.
    """
    schema_tables = self._get_schema_tables_info(connection, schema, **kw)
    return [
        name
        for name, details in schema_tables.items()
        if prefix in details["prefixes"]
    ]
922
+
923
def get_multi_indexes(
    self,
    connection,
    *,
    schema: Optional[str] = None,
    filter_names: Optional[Collection[str]] = None,
    **kw,
):
    """
    Gets the indexes definition for the hybrid tables of a schema.

    :param connection: connection used to execute ``SHOW INDEXES``.
    :param schema: schema to inspect; falls back to
        ``self.default_schema_name`` when empty.
    :param filter_names: optional collection of normalized table names to
        restrict the result to. ``None`` means "do not filter".
    :param kw: extra reflection arguments; only ``info_cache`` is forwarded
        to the hybrid-table lookup.
    :return: a list of ``((schema, table_name), [index_dict, ...])`` pairs,
        or an empty list when the schema has no hybrid tables.
    """
    schema = schema or self.default_schema_name
    # A set makes the per-row membership test below constant time.
    hybrid_table_names = set(
        self.get_table_names_with_prefix(
            connection,
            schema=schema,
            prefix=CustomTablePrefix.HYBRID.name,
            info_cache=kw.get("info_cache", None),
        )
    )
    if not hybrid_table_names:
        return []

    result = connection.execute(
        text(
            f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}"
        )
    )

    n2i = self._map_name_to_idx(result)
    indexes = {}

    for row in result.cursor.fetchall():
        raw_table_name = row[n2i["table"]]
        table_name = self.normalize_name(str(raw_table_name))
        if (
            # Skip the index named after the table's primary key.
            row[n2i["name"]] == f"SYS_INDEX_{raw_table_name}_PRIMARY"
            # Only filter when the caller actually supplied names; the
            # default None used to raise TypeError on the ``in`` test.
            or (filter_names is not None and table_name not in filter_names)
            or table_name not in hybrid_table_names
        ):
            continue
        index = {
            "name": row[n2i["name"]],
            "unique": row[n2i["is_unique"]] == "Y",
            "column_names": [
                self.normalize_name(column)
                for column in parse_index_columns(row[n2i["columns"]])
            ],
            "include_columns": [
                self.normalize_name(column)
                for column in parse_index_columns(row[n2i["included_columns"]])
            ],
            "dialect_options": {},
        }

        # list.append returns None, so its result must never be stored
        # back into the mapping; accumulate in place instead.
        indexes.setdefault((schema, table_name), []).append(index)

    return list(indexes.items())
983
+
984
def _value_or_default(self, data, table, schema):
    """
    Look up the entry for ``(schema, table)`` in ``data``.

    ``data`` is anything ``dict()`` accepts (for example a list of
    key/value pairs); ``table`` is normalized before the lookup. An empty
    list is returned when the key is absent.
    """
    lookup_key = (schema, self.normalize_name(str(table)))
    return dict(data).get(lookup_key, [])
991
+
992
@reflection.cache
def get_indexes(self, connection, tablename, schema, **kw):
    """
    Gets the indexes definition of a single table.

    :param connection: connection used for reflection.
    :param tablename: name of the table; normalized before use.
    :param schema: schema of the table; falls back to
        ``self.default_schema_name`` when empty.
    :return: the list of index dictionaries for the table, or an empty
        list when none are reported.
    """
    table_name = self.normalize_name(str(tablename))
    # get_multi_indexes keys its result by the resolved schema, so resolve
    # it here too; otherwise a lookup with schema=None can never match.
    schema = schema or self.default_schema_name
    data = self.get_multi_indexes(
        connection=connection, schema=schema, filter_names=[table_name], **kw
    )

    return self._value_or_default(data, table_name, schema)
1003
+
1004
def connect(self, *cargs, **cparams):
    """
    Open a connection through the parent dialect.

    When ``_ENABLE_SQLALCHEMY_AS_APPLICATION_NAME`` is truthy, the keyword
    parameters are first passed through
    ``_update_connection_application_name``; otherwise they are forwarded
    unchanged.
    """
    if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME:
        cparams = _update_connection_application_name(**cparams)
    return super().connect(*cargs, **cparams)
1017
+
1018
+
1019
@sa_vnt.listens_for(Table, "before_create")
def check_table(table, connection, _ddl_runner, **kw):
    """
    ``before_create`` listener that rejects indexes on non-hybrid tables.

    Returns ``True`` for hybrid tables. For any other table it raises
    ``NotImplementedError`` when the DDL runner's dialect is
    ``SnowflakeDialect`` and the table declares indexes; otherwise it
    returns ``None``.
    """
    # Imported here rather than at module level, as in the original.
    from .sql.custom_schema.hybrid_table import HybridTable

    if HybridTable.is_equal_type(table):  # noqa
        return True

    targets_snowflake = isinstance(_ddl_runner.dialect, SnowflakeDialect)
    if targets_snowflake and table.indexes:
        raise NotImplementedError("Only Snowflake Hybrid Tables supports indexes")
1027
+
1028
+
1029
# Module-level alias for the dialect class defined in this file.
# NOTE(review): presumably the attribute SQLAlchemy looks up when loading
# this module as a dialect — confirm against the package entry points.
dialect = SnowflakeDialect