snowflake-sqlalchemy 1.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1065 @@
1
+ #
2
+ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+
5
+ import itertools
6
+ import operator
7
+ import re
8
+
9
+ from sqlalchemy import exc as sa_exc
10
+ from sqlalchemy import inspect, sql
11
+ from sqlalchemy import util as sa_util
12
+ from sqlalchemy.engine import default
13
+ from sqlalchemy.orm import context
14
+ from sqlalchemy.orm.context import _MapperEntity
15
+ from sqlalchemy.schema import Sequence, Table
16
+ from sqlalchemy.sql import compiler, expression
17
+ from sqlalchemy.sql.base import CompileState
18
+ from sqlalchemy.sql.elements import quoted_name
19
+ from sqlalchemy.sql.selectable import Lateral, SelectState
20
+ from sqlalchemy.util.compat import string_types
21
+
22
+ from .custom_commands import AWSBucket, AzureContainer, ExternalStage
23
+ from .util import (
24
+ _find_left_clause_to_join_from,
25
+ _set_connection_interpolate_empty_sequences,
26
+ _Snowflake_ORMJoin,
27
+ _Snowflake_Selectable_Join,
28
+ )
29
+
# Words that always force an identifier to be quoted (SnowflakeIdentifierPreparer
# lower-cases this set).  Grouped by origin; order is irrelevant (frozenset).
RESERVED_WORDS = frozenset(
    # ANSI reserved words (plus REGEXP / RLIKE)
    """
    ALL ALTER AND ANY AS BETWEEN BY CHECK COLUMN CONNECT
    COPY CREATE CURRENT DELETE DISTINCT DROP ELSE EXISTS FOR FROM
    GRANT GROUP HAVING IN INSERT INTERSECT INTO IS LIKE NOT
    NULL OF ON OR ORDER REVOKE ROW ROWS SAMPLE SELECT
    SET START TABLE THEN TO TRIGGER UNION UNIQUE UPDATE VALUES
    WHENEVER WHERE WITH REGEXP RLIKE
    """.split()
    # Snowflake reserved words
    + ["SOME", "MINUS"]
    # Oracle reserved words
    + ["INCREMENT"]
)
92
+
# Leading keywords of statements that change data and therefore trigger an
# autocommit when executed as plain text (used by
# SnowflakeExecutionContext.should_autocommit_text):
#   - Snowflake DML: UPDATE, INSERT, DELETE, MERGE
#   - COPY (i.e. COPY INTO ...)
# Only the start of the statement is checked (re.match), after optional
# whitespace, case-insensitively.
AUTOCOMMIT_REGEXP = re.compile(
    r"\s*(?:UPDATE|INSERT|DELETE|MERGE|COPY)",
    flags=re.IGNORECASE | re.UNICODE,
)


# The SELECT compile-state classes below override SQLAlchemy methods to handle
# the Snowflake behavior change BCR-1057:
# https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057
#   - _setup_joins (Core SELECT)
#   - _join_determine_implicit_left_side (Core and ORM SELECT)
#   - _join_left_to_right (ORM SELECT)
109
+
110
+
# handle Snowflake BCR bcr-1057
@CompileState.plugin_for("default", "select")
class SnowflakeSelectState(SelectState):
    """SELECT compile state (plugin "default") that builds Snowflake joins.

    Re-implements the join-building steps of SQLAlchemy's ``SelectState`` so
    that:

    * the implicit left side of a join is located with this package's
      ``_find_left_clause_to_join_from``, and
    * joins are created as ``_Snowflake_Selectable_Join`` objects,

    to handle Snowflake BCR-1057:
    https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057
    """

    def _setup_joins(self, args, raw_columns):
        """Apply each pending join in ``args`` to ``self.from_clauses``.

        :param args: iterable of ``(right, onclause, left, flags)`` tuples;
            ``flags`` must provide the ``"isouter"`` and ``"full"`` keys.
        :param raw_columns: the statement's column expressions; used to find
            an implicit left side when ``left`` is ``None``.
        """
        for right, onclause, left, flags in args:
            isouter = flags["isouter"]
            full = flags["full"]

            if left is None:
                # No explicit left side: choose one from the existing FROMs
                # (or, if there are none, from the columns / WHERE criteria).
                (
                    left,
                    replace_from_obj_index,
                ) = self._join_determine_implicit_left_side(
                    raw_columns, left, right, onclause
                )
            else:
                # Explicit left side: find out whether it is already one of
                # the FROM clauses (an index) or not (None).
                (replace_from_obj_index) = self._join_place_explicit_left_side(left)

            if replace_from_obj_index is not None:
                # splice into an existing element in the
                # self._from_obj list
                # (the FROM entry at that index becomes the left side of the
                # new join and is replaced by it in place; from_clauses is a
                # tuple here)
                left_clause = self.from_clauses[replace_from_obj_index]

                self.from_clauses = (
                    self.from_clauses[:replace_from_obj_index]
                    + (
                        _Snowflake_Selectable_Join(  # handle Snowflake BCR bcr-1057
                            left_clause,
                            right,
                            onclause,
                            isouter=isouter,
                            full=full,
                        ),
                    )
                    + self.from_clauses[replace_from_obj_index + 1 :]
                )
            else:
                # The left side is not an existing FROM entry: append a new
                # join to the end of the FROM list.
                self.from_clauses = self.from_clauses + (
                    # handle Snowflake BCR bcr-1057
                    _Snowflake_Selectable_Join(
                        left, right, onclause, isouter=isouter, full=full
                    ),
                )

    # NOTE(review): the preloaded "sqlalchemy.sql.util" module is not
    # referenced in this body; _find_left_clause_to_join_from is used instead.
    @sa_util.preload_module("sqlalchemy.sql.util")
    def _join_determine_implicit_left_side(self, raw_columns, left, right, onclause):
        """When join conditions don't express the left side explicitly,
        determine if an existing FROM or entity in this query
        can serve as the left hand side.

        Candidates are the current FROM clauses when there are any. Otherwise
        they are the FROM objects of ``raw_columns`` followed by those of the
        statement's WHERE criteria, de-duplicated in order of first
        appearance.

        :returns: ``(left, replace_from_obj_index)``; the index is only set
            when ``left`` is one of the existing FROM clauses.
        :raises sa_exc.InvalidRequestError: if more than one candidate, or
            none, can join to ``right``.
        """

        replace_from_obj_index = None

        from_clauses = self.from_clauses

        if from_clauses:
            # handle Snowflake BCR bcr-1057
            indexes = _find_left_clause_to_join_from(from_clauses, right, onclause)

            if len(indexes) == 1:
                replace_from_obj_index = indexes[0]
                left = from_clauses[replace_from_obj_index]
            # zero or several matches fall through to the errors below
        else:
            potential = {}
            statement = self.statement

            for from_clause in itertools.chain(
                itertools.chain.from_iterable(
                    [element._from_objects for element in raw_columns]
                ),
                itertools.chain.from_iterable(
                    [element._from_objects for element in statement._where_criteria]
                ),
            ):

                # dict keys de-duplicate while keeping first-seen order
                potential[from_clause] = ()

            all_clauses = list(potential.keys())
            # handle Snowflake BCR bcr-1057
            indexes = _find_left_clause_to_join_from(all_clauses, right, onclause)

            if len(indexes) == 1:
                left = all_clauses[indexes[0]]

        if len(indexes) > 1:
            raise sa_exc.InvalidRequestError(
                "Can't determine which FROM clause to join "
                "from, there are multiple FROMS which can "
                "join to this entity. Please use the .select_from() "
                "method to establish an explicit left side, as well as "
                "providing an explicit ON clause if not present already to "
                "help resolve the ambiguity."
            )
        elif not indexes:
            raise sa_exc.InvalidRequestError(
                "Don't know how to join to %r. "
                "Please use the .select_from() "
                "method to establish an explicit left side, as well as "
                "providing an explicit ON clause if not present already to "
                "help resolve the ambiguity." % (right,)
            )
        return left, replace_from_obj_index
214
+
215
+
# handle Snowflake BCR bcr-1057
@sql.base.CompileState.plugin_for("orm", "select")
class SnowflakeORMSelectCompileState(context.ORMSelectCompileState):
    """ORM SELECT compile state (plugin "orm") that builds Snowflake joins.

    Overrides two join helpers of SQLAlchemy's ``ORMSelectCompileState`` so
    that:

    * the implicit left side is found with this package's
      ``_find_left_clause_to_join_from``, and
    * new joins are created as ``_Snowflake_ORMJoin`` objects,

    to handle Snowflake BCR-1057:
    https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057
    """

    def _join_determine_implicit_left_side(
        self, entities_collection, left, right, onclause
    ):
        """When join conditions don't express the left side explicitly,
        determine if an existing FROM or entity in this query
        can serve as the left hand side.

        :param entities_collection: the query's entities; only consulted
            when there are no FROM clauses yet.
        :param left: overwritten with the chosen left side (the caller in
            this class passes ``None``).
        :returns: ``(left, replace_from_obj_index, use_entity_index)``:

            * ``replace_from_obj_index`` is the index of the FROM clause to
              splice the join into, or ``None``;
            * ``use_entity_index`` is the index of the ``_MapperEntity`` whose
              selectable should be the left side, or ``None``.
        :raises sa_exc.InvalidRequestError: if zero or several candidates can
            join to ``right``, or if there are neither FROMs nor entities.
        """

        # when we are here, it means join() was called without an ORM-
        # specific way of telling us what the "left" side is, e.g.:
        #
        # join(RightEntity)
        #
        # or
        #
        # join(RightEntity, RightEntity.foo == LeftEntity.bar)
        #

        r_info = inspect(right)

        replace_from_obj_index = use_entity_index = None

        if self.from_clauses:
            # we have a list of FROMs already. So by definition this
            # join has to connect to one of those FROMs.

            # handle Snowflake BCR bcr-1057
            indexes = _find_left_clause_to_join_from(
                self.from_clauses, r_info.selectable, onclause
            )

            if len(indexes) == 1:
                replace_from_obj_index = indexes[0]
                left = self.from_clauses[replace_from_obj_index]
            elif len(indexes) > 1:
                raise sa_exc.InvalidRequestError(
                    "Can't determine which FROM clause to join "
                    "from, there are multiple FROMS which can "
                    "join to this entity. Please use the .select_from() "
                    "method to establish an explicit left side, as well as "
                    "providing an explicit ON clause if not present already "
                    "to help resolve the ambiguity."
                )
            else:
                raise sa_exc.InvalidRequestError(
                    "Don't know how to join to %r. "
                    "Please use the .select_from() "
                    "method to establish an explicit left side, as well as "
                    "providing an explicit ON clause if not present already "
                    "to help resolve the ambiguity." % (right,)
                )

        elif entities_collection:
            # we have no explicit FROMs, so the implicit left has to
            # come from our list of entities.

            # maps selectable -> (entity index or None, entity)
            potential = {}
            for entity_index, ent in enumerate(entities_collection):
                entity = ent.entity_zero_or_selectable
                if entity is None:
                    continue
                ent_info = inspect(entity)
                if ent_info is r_info:  # left and right are the same, skip
                    continue

                # by using a dictionary with the selectables as keys this
                # de-duplicates those selectables as occurs when the query is
                # against a series of columns from the same selectable
                if isinstance(ent, context._MapperEntity):
                    potential[ent.selectable] = (entity_index, entity)
                else:
                    potential[ent_info.selectable] = (None, entity)

            all_clauses = list(potential.keys())
            # handle Snowflake BCR bcr-1057
            indexes = _find_left_clause_to_join_from(
                all_clauses, r_info.selectable, onclause
            )

            if len(indexes) == 1:
                use_entity_index, left = potential[all_clauses[indexes[0]]]
            elif len(indexes) > 1:
                raise sa_exc.InvalidRequestError(
                    "Can't determine which FROM clause to join "
                    "from, there are multiple FROMS which can "
                    "join to this entity. Please use the .select_from() "
                    "method to establish an explicit left side, as well as "
                    "providing an explicit ON clause if not present already "
                    "to help resolve the ambiguity."
                )
            else:
                raise sa_exc.InvalidRequestError(
                    "Don't know how to join to %r. "
                    "Please use the .select_from() "
                    "method to establish an explicit left side, as well as "
                    "providing an explicit ON clause if not present already "
                    "to help resolve the ambiguity." % (right,)
                )
        else:
            raise sa_exc.InvalidRequestError(
                "No entities to join from; please use "
                "select_from() to establish the left "
                "entity/selectable of this join"
            )

        return left, replace_from_obj_index, use_entity_index

    def _join_left_to_right(
        self,
        entities_collection,
        left,
        right,
        onclause,
        prop,
        create_aliases,
        aliased_generation,
        outerjoin,
        full,
    ):
        """given raw "left", "right", "onclause" parameters consumed from
        a particular key within _join(), add a real ORMJoin object to
        our _from_obj list (or augment an existing one)

        The join is built as a ``_Snowflake_ORMJoin``. ``outerjoin`` and
        ``full`` are passed to it as ``isouter`` and ``full``.

        :raises sa_exc.InvalidRequestError: if ``left`` is ``right`` and
            ``create_aliases`` is false (or from the implicit left-side
            lookup).
        """

        if left is None:
            # left not given (e.g. no relationship object/name specified)
            # figure out the best "left" side based on our existing froms /
            # entities
            assert prop is None
            (
                left,
                replace_from_obj_index,
                use_entity_index,
            ) = self._join_determine_implicit_left_side(
                entities_collection, left, right, onclause
            )
        else:
            # left is given via a relationship/name, or as explicit left side.
            # Determine where in our
            # "froms" list it should be spliced/appended as well as what
            # existing entity it corresponds to.
            (
                replace_from_obj_index,
                use_entity_index,
            ) = self._join_place_explicit_left_side(entities_collection, left)

        if left is right and not create_aliases:
            raise sa_exc.InvalidRequestError(
                "Can't construct a join from %s to %s, they "
                "are the same entity" % (left, right)
            )

        # the right side as given often needs to be adapted. additionally
        # a lot of things can be wrong with it. handle all that and
        # get back the new effective "right" side
        r_info, right, onclause = self._join_check_and_adapt_right_side(
            left, right, onclause, prop, create_aliases, aliased_generation
        )

        # extra criteria are only collected for non-selectable (mapped) right
        # sides
        if not r_info.is_selectable:
            extra_criteria = self._get_extra_criteria(r_info)
        else:
            extra_criteria = ()

        if replace_from_obj_index is not None:
            # splice into an existing element in the
            # self._from_obj list
            # (from_clauses is a list here)
            left_clause = self.from_clauses[replace_from_obj_index]

            self.from_clauses = (
                self.from_clauses[:replace_from_obj_index]
                + [
                    _Snowflake_ORMJoin(  # handle Snowflake BCR bcr-1057
                        left_clause,
                        right,
                        onclause,
                        isouter=outerjoin,
                        full=full,
                        _extra_criteria=extra_criteria,
                    )
                ]
                + self.from_clauses[replace_from_obj_index + 1 :]
            )
        else:
            # add a new element to the self._from_obj list
            if use_entity_index is not None:
                # make use of _MapperEntity selectable, which is usually
                # entity_zero.selectable, but if with_polymorphic() were used
                # might be distinct
                assert isinstance(entities_collection[use_entity_index], _MapperEntity)
                left_clause = entities_collection[use_entity_index].selectable
            else:
                left_clause = left

            # NOTE(review): this branch passes r_info as the right side while
            # the splice branch above passes the adapted `right`.
            self.from_clauses = self.from_clauses + [
                _Snowflake_ORMJoin(  # handle Snowflake BCR bcr-1057
                    left_clause,
                    r_info,
                    onclause,
                    isouter=outerjoin,
                    full=full,
                    _extra_criteria=extra_criteria,
                )
            ]
425
+
426
+
class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier quoting rules for Snowflake.

    A double quote is used both as the identifier quote character and as the
    escape character. Multi-part schema names such as ``db.schema`` are split
    on unquoted dots, and each part is quoted on its own.
    """

    # Lower-cased copy of RESERVED_WORDS; SQLAlchemy's base preparer consults
    # ``reserved_words`` when deciding whether an identifier must be quoted.
    reserved_words = {x.lower() for x in RESERVED_WORDS}

    def __init__(self, dialect, **kw):
        # NOTE(review): extra keyword arguments are accepted but ignored.
        quote = '"'

        super().__init__(dialect, initial_quote=quote, escape_quote=quote)

    def _quote_free_identifiers(self, *ids):
        """Apply ``self.quote`` to every non-None identifier.

        :returns: tuple of the quoted identifiers; ``None`` entries are
            dropped.
        """
        return tuple(self.quote(i) for i in ids if i is not None)

    def quote_schema(self, schema, force=None):
        """Quote each unquoted-dot-separated part of ``schema`` separately.

        The parts are joined back together with ``"."``. ``force`` is
        accepted for signature compatibility and ignored.
        """
        idents = self._split_schema_by_dot(schema)
        return ".".join(self._quote_free_identifiers(*idents))

    def format_label(self, label, name=None):
        """Render a label name with any embedded double quotes removed.

        ``name`` takes precedence over ``label.name`` when truthy.

        * A plain string, or a ``quoted_name`` with ``quote=None``, goes
          through the regular ``quote()`` rules.
        * ``quote=True`` forces quoting.
        * ``quote=False`` renders the name verbatim.
        """
        n = name or label.name
        s = n.replace(self.escape_quote, "")

        if not isinstance(n, quoted_name) or n.quote is None:
            return self.quote(s)

        return self.quote_identifier(s) if n.quote else s

    def _split_schema_by_dot(self, schema):
        """Split a (possibly multi-part) schema name on unquoted dots.

        Double-quoted parts are returned without their quotes and may contain
        dots, e.g. ``'"my.db".public'`` -> ``['my.db', 'public']`` and
        ``'"a".b.c'`` -> ``['a', 'b', 'c']``. Empty parts are skipped.

        NOTE(review): a doubled quote (``""``) inside a quoted part is not
        treated as an escaped quote character.

        :returns: list of ``quoted_name`` objects. Each carries the ``quote``
            attribute of ``schema``, or ``None`` when ``schema`` is a plain
            ``str``.
        """
        ret = []
        end = len(schema)
        pre_idx = 0  # start of the part currently being scanned
        in_quote = False
        for idx, char in enumerate(schema):
            if not in_quote:
                if char == ".":
                    if pre_idx < idx:
                        ret.append(schema[pre_idx:idx])
                    # Always step past the separator. When a dot directly
                    # follows a closing quote, pre_idx == idx here. Leaving
                    # pre_idx on the dot would glue it onto the next part
                    # ('"a".b.c' would yield ['a', '.b', 'c']).
                    pre_idx = idx + 1
                elif char == '"':
                    in_quote = True
                    pre_idx = idx + 1
            elif char == '"' and pre_idx < idx:
                # closing quote of a non-empty quoted part
                ret.append(schema[pre_idx:idx])
                in_quote = False
                pre_idx = idx + 1
        # A dot can still be pending at pre_idx when a quote was never closed
        # (e.g. '".a'); skip it before collecting the trailing part.
        if pre_idx < end and schema[pre_idx] == ".":
            pre_idx += 1
        if pre_idx < end:
            ret.append(schema[pre_idx:end])

        # convert the returned strings back to quoted_name, carrying over the
        # original 'quote' attribute of the schema
        return [
            quoted_name(value, quote=getattr(schema, "quote", None)) for value in ret
        ]
487
+
488
+
class SnowflakeCompiler(compiler.SQLCompiler):
    """SQL statement compiler for Snowflake.

    Renders the Snowflake-specific constructs from ``custom_commands``:
    MERGE INTO, COPY INTO, file formats, S3/Azure locations and external
    stages.

    Also adapts several generic constructs to Snowflake syntax:

    * sequences and ``now()``;
    * regular-expression operators;
    * joins (including LATERAL without ON);
    * ``DELETE ... USING`` and ``UPDATE ... FROM``;
    * backslash escaping in literals.
    """

    def visit_sequence(self, sequence, **kw):
        """Render a sequence reference as ``<sequence>.nextval``."""
        return self.dialect.identifier_preparer.format_sequence(sequence) + ".nextval"

    def visit_now_func(self, now, **kw):
        """Render ``func.now()`` as ``CURRENT_TIMESTAMP``."""
        return "CURRENT_TIMESTAMP"

    def visit_merge_into(self, merge_into, **kw):
        """Render ``MERGE INTO <target> USING <source> ON <on> [clauses]``.

        NOTE(review): target, source and the ON condition are interpolated
        with ``str()`` rather than compiled with this compiler.
        """
        clauses = " ".join(
            clause._compiler_dispatch(self, **kw) for clause in merge_into.clauses
        )
        return (
            f"MERGE INTO {merge_into.target} USING {merge_into.source} ON {merge_into.on}"
            + (" " + clauses if clauses else "")
        )

    def visit_merge_into_clause(self, merge_into_clause, **kw):
        """Render one ``WHEN [NOT] MATCHED`` clause of a MERGE statement.

        * The ``INSERT`` command renders
          ``WHEN NOT MATCHED [AND pred] THEN INSERT (cols) VALUES (exprs)``.
        * Any other command renders
          ``WHEN MATCHED [AND pred] THEN <command> [SET col = expr, ...]``.

        With ``deterministic=True`` in ``kw``, the assignments are sorted by
        column name.

        :raises ValueError: for an INSERT clause with no assignments.
        """
        case_predicate = (
            f" AND {str(merge_into_clause.predicate._compiler_dispatch(self, **kw))}"
            if merge_into_clause.predicate is not None
            else ""
        )
        assignments = list(merge_into_clause.set.items())
        if kw.get("deterministic", False):
            assignments.sort(key=operator.itemgetter(0))

        if merge_into_clause.command == "INSERT":
            # unpacking an empty zip raises ValueError: INSERT needs values
            columns, values = zip(*assignments)
            return "WHEN NOT MATCHED{} THEN {} ({}) VALUES ({})".format(
                case_predicate,
                merge_into_clause.command,
                ", ".join(columns),
                ", ".join(value._compiler_dispatch(self, **kw) for value in values),
            )

        sets = (
            ", ".join(
                f"{column} = {value._compiler_dispatch(self, **kw)}"
                for column, value in assignments
            )
            if merge_into_clause.set
            else ""
        )
        return "WHEN MATCHED{} THEN {}{}".format(
            case_predicate,
            merge_into_clause.command,
            " SET %s" % sets if merge_into_clause.set else "",
        )

    def visit_copy_into(self, copy_into, **kw):
        """Render ``COPY INTO <into> FROM <from> <file format><options>``.

        ``into`` and ``from_`` are handled as follows:

        * ``Table`` objects are interpolated as-is.
        * ``AWSBucket``, ``AzureContainer`` and ``ExternalStage`` sources are
          compiled.
        * Anything else (e.g. a SELECT) is compiled and wrapped in
          parentheses.

        Bucket/container visitors return ``(uri, credentials, encryption)``.
        The credentials and encryption parts are appended after the copy
        options. With ``deterministic=True`` in ``kw``, the options are
        sorted by name.
        """
        formatter = getattr(copy_into, "formatter", None)
        if formatter is not None:
            formatter = formatter._compiler_dispatch(self, **kw)
        else:
            formatter = ""
        into = (
            copy_into.into
            if isinstance(copy_into.into, Table)
            else copy_into.into._compiler_dispatch(self, **kw)
        )
        if isinstance(copy_into.from_, Table):
            from_ = copy_into.from_
        # this is intended to catch AWSBucket, AzureContainer and ExternalStage
        elif isinstance(copy_into.from_, (AWSBucket, AzureContainer, ExternalStage)):
            from_ = copy_into.from_._compiler_dispatch(self, **kw)
        # everything else (selects, etc.)
        else:
            from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})"
        credentials, encryption = "", ""
        if isinstance(into, tuple):
            into, credentials, encryption = into
        elif isinstance(from_, tuple):
            from_, credentials, encryption = from_
        options_list = list(copy_into.copy_options.items())
        if kw.get("deterministic", False):
            options_list.sort(key=operator.itemgetter(0))
        options = (
            " "
            + " ".join(
                "{} = {}".format(
                    name,
                    # SQL expressions (anything with SQLAlchemy's
                    # ``_compiler_dispatch``) are compiled with this compiler,
                    # as in visit_copy_formatter; other values use str().
                    (
                        value._compiler_dispatch(self, **kw)
                        if hasattr(value, "_compiler_dispatch")
                        else str(value)
                    ),
                )
                for name, value in options_list
            )
            if copy_into.copy_options
            else ""
        )
        if credentials:
            options += f" {credentials}"
        if encryption:
            options += f" {encryption}"
        return f"COPY INTO {into} FROM {from_} {formatter}{options}"

    def visit_copy_formatter(self, formatter, **kw):
        """Render the ``FILE_FORMAT=(...)`` part of a COPY INTO statement.

        A ``format_name`` option takes precedence and renders
        ``FILE_FORMAT=(format_name = <name>)``. Otherwise the format type and
        options are rendered. Option values that are SQL expressions are
        compiled; other values go through ``formatter.value_repr``.
        """
        options_list = list(formatter.options.items())
        if kw.get("deterministic", False):
            options_list.sort(key=operator.itemgetter(0))
        if "format_name" in formatter.options:
            return f"FILE_FORMAT=(format_name = {formatter.options['format_name']})"
        return "FILE_FORMAT=(TYPE={}{})".format(
            formatter.file_format,
            (
                " "
                + " ".join(
                    [
                        "{}={}".format(
                            name,
                            (
                                value._compiler_dispatch(self, **kw)
                                if hasattr(value, "_compiler_dispatch")
                                else formatter.value_repr(name, value)
                            ),
                        )
                        for name, value in options_list
                    ]
                )
                if formatter.options
                else ""
            ),
        )

    def visit_aws_bucket(self, aws_bucket, **kw):
        """Return ``(uri, credentials, encryption)`` for an S3 location.

        ``credentials``/``encryption`` are empty strings when unused.
        NOTE(review): credential values are wrapped in single quotes without
        escaping.
        """
        credentials_list = list(aws_bucket.credentials_used.items())
        if kw.get("deterministic", False):
            credentials_list.sort(key=operator.itemgetter(0))
        credentials = "CREDENTIALS=({})".format(
            " ".join(f"{n}='{v}'" for n, v in credentials_list)
        )
        encryption_list = list(aws_bucket.encryption_used.items())
        if kw.get("deterministic", False):
            encryption_list.sort(key=operator.itemgetter(0))
        encryption = "ENCRYPTION=({})".format(
            " ".join(
                # string values are quoted, everything else is rendered bare
                ("{}='{}'" if isinstance(v, string_types) else "{}={}").format(n, v)
                for n, v in encryption_list
            )
        )
        uri = "'s3://{}{}'".format(
            aws_bucket.bucket, f"/{aws_bucket.path}" if aws_bucket.path else ""
        )
        return (
            uri,
            credentials if aws_bucket.credentials_used else "",
            encryption if aws_bucket.encryption_used else "",
        )

    def visit_azure_container(self, azure_container, **kw):
        """Return ``(uri, credentials, encryption)`` for an Azure blob location.

        ``credentials``/``encryption`` are empty strings when unused.
        NOTE(review): credential values are wrapped in single quotes without
        escaping.
        """
        credentials_list = list(azure_container.credentials_used.items())
        if kw.get("deterministic", False):
            credentials_list.sort(key=operator.itemgetter(0))
        credentials = "CREDENTIALS=({})".format(
            " ".join(f"{n}='{v}'" for n, v in credentials_list)
        )
        encryption_list = list(azure_container.encryption_used.items())
        if kw.get("deterministic", False):
            encryption_list.sort(key=operator.itemgetter(0))
        encryption = "ENCRYPTION=({})".format(
            " ".join(
                f"{n}='{v}'" if isinstance(v, string_types) else f"{n}={v}"
                for n, v in encryption_list
            )
        )
        uri = "'azure://{}.blob.core.windows.net/{}{}'".format(
            azure_container.account,
            azure_container.container,
            f"/{azure_container.path}" if azure_container.path else "",
        )
        return (
            uri,
            credentials if azure_container.credentials_used else "",
            encryption if azure_container.encryption_used else "",
        )

    def visit_external_stage(self, external_stage, **kw):
        """Render ``@<namespace><name><path>`` plus an optional file format."""
        if external_stage.file_format is None:
            return (
                f"@{external_stage.namespace}{external_stage.name}{external_stage.path}"
            )
        return f"@{external_stage.namespace}{external_stage.name}{external_stage.path} (file_format => {external_stage.file_format})"

    def delete_extra_from_clause(
        self, delete_stmt, from_table, extra_froms, from_hints, **kw
    ):
        """Render extra DELETE tables as ``USING t1, t2, ...``."""
        return "USING " + ", ".join(
            t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
            for t in extra_froms
        )

    def update_from_clause(
        self, update_stmt, from_table, extra_froms, from_hints, **kw
    ):
        """Render extra UPDATE tables as ``FROM t1, t2, ...``."""
        return "FROM " + ", ".join(
            t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
            for t in extra_froms
        )

    def _get_regexp_args(self, binary, kw):
        """Compile the string, pattern and (optional) flags of a regexp op."""
        string = self.process(binary.left, **kw)
        pattern = self.process(binary.right, **kw)
        flags = binary.modifiers["flags"]
        if flags is not None:
            flags = self.process(flags, **kw)
        return string, pattern, flags

    def visit_regexp_match_op_binary(self, binary, operator, **kw):
        """Render ``regexp_match`` as ``REGEXP_LIKE(string, pattern[, flags])``."""
        string, pattern, flags = self._get_regexp_args(binary, kw)
        if flags is None:
            return f"REGEXP_LIKE({string}, {pattern})"
        else:
            return f"REGEXP_LIKE({string}, {pattern}, {flags})"

    def visit_regexp_replace_op_binary(self, binary, operator, **kw):
        """Render ``regexp_replace`` as ``REGEXP_REPLACE(...)``."""
        string, pattern, flags = self._get_regexp_args(binary, kw)
        try:
            replacement = self.process(binary.modifiers["replacement"], **kw)
        except KeyError:
            # in sqlalchemy 1.4.49, the internal structure of the expression is changed
            # that binary.modifiers doesn't have "replacement":
            # https://docs.sqlalchemy.org/en/20/changelog/changelog_14.html#change-1.4.49
            return f"REGEXP_REPLACE({string}, {pattern}{'' if flags is None else f', {flags}'})"

        if flags is None:
            return f"REGEXP_REPLACE({string}, {pattern}, {replacement})"
        else:
            return f"REGEXP_REPLACE({string}, {pattern}, {replacement}, {flags})"

    def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
        """Render a negated ``regexp_match`` as ``NOT REGEXP_LIKE(...)``."""
        return f"NOT {self.visit_regexp_match_op_binary(binary, operator, **kw)}"

    def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
        """Render a join; omit ``ON`` for a LATERAL right side with no onclause."""
        if from_linter:
            from_linter.edges.update(
                itertools.product(join.left._from_objects, join.right._from_objects)
            )

        if join.full:
            join_type = " FULL OUTER JOIN "
        elif join.isouter:
            join_type = " LEFT OUTER JOIN "
        else:
            join_type = " JOIN "

        join_statement = (
            join.left._compiler_dispatch(
                self, asfrom=True, from_linter=from_linter, **kwargs
            )
            + join_type
            + join.right._compiler_dispatch(
                self, asfrom=True, from_linter=from_linter, **kwargs
            )
        )

        if join.onclause is None and isinstance(join.right, Lateral):
            # in snowflake, onclause is not accepted for lateral due to BCR change:
            # https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057
            # sqlalchemy only allows join with on condition.
            # to adapt to snowflake syntax change,
            # we make the change such that when onclause is None and the right part is
            # Lateral, we do not append the on condition
            return join_statement

        return (
            join_statement
            + " ON "
            # TODO: likely need asfrom=True here?
            + join.onclause._compiler_dispatch(self, from_linter=from_linter, **kwargs)
        )

    def render_literal_value(self, value, type_):
        """Render a literal, doubling backslashes (escape backslash)."""
        return super().render_literal_value(value, type_).replace("\\", "\\\\")
779
+
780
+
class SnowflakeExecutionContext(default.DefaultExecutionContext):
    """Execution context for the Snowflake dialect.

    Responsibilities:

    * fetches sequence values with ``SELECT <sequence>.nextval``;
    * decides text autocommit with ``AUTOCOMMIT_REGEXP``;
    * reports the cursor's rowcount;
    * around each execution, switches the DBAPI connection's
      ``_interpolate_empty_sequences`` setting to match how ``%`` was
      escaped in the SQL.
    """

    # statements starting with "INSERT INTO", in any letter case
    INSERT_SQL_RE = re.compile(r"^insert\s+into", flags=re.IGNORECASE)

    def fire_sequence(self, seq, type_):
        """Execute ``SELECT <seq>.nextval`` and return the scalar result."""
        sequence_sql = self.identifier_preparer.format_sequence(seq)
        return self._execute_scalar(f"SELECT {sequence_sql}.nextval", type_)

    def should_autocommit_text(self, statement):
        """Return a match object (truthy) if ``statement`` changes data."""
        return AUTOCOMMIT_REGEXP.match(statement)

    @sa_util.memoized_property
    def should_autocommit(self):
        """Whether this execution should autocommit.

        An explicit ``autocommit`` execution option wins. Otherwise,
        uncompiled textual statements are inspected with
        ``should_autocommit_text``. DDL never autocommits through the option
        path.
        """
        implicit_default = (
            not self.compiled and self.statement and expression.PARSE_AUTOCOMMIT
        ) or False
        autocommit = self.execution_options.get("autocommit", implicit_default)

        if autocommit is expression.PARSE_AUTOCOMMIT:
            return self.should_autocommit_text(self.unicode_statement)
        return autocommit and not self.isddl

    def pre_exec(self):
        """Set up percent-sign handling on the DBAPI connection."""
        percents_doubled = self.compiled and self.identifier_preparer._double_percents
        if not percents_doubled:
            # "%" is not double-escaped, so do not interpolate empty sequences
            _set_connection_interpolate_empty_sequences(self._dbapi_connection, False)
            return

        # compiled statements escape "%" as "%%": turn on
        # _interpolate_empty_sequences for this execution
        _set_connection_interpolate_empty_sequences(self._dbapi_connection, True)

        # For an executemany INSERT the flag alone is not enough. executemany
        # pre-processes the parameter binding and then passes None params to
        # execute, so the _interpolate_empty_sequences condition is never met
        # for that command. Undo the "%" escaping in the statement text
        # instead.
        is_bulk_insert = self.executemany and self.INSERT_SQL_RE.match(self.statement)
        if is_bulk_insert:
            self.statement = self.statement.replace("%%", "%")

    def post_exec(self):
        """Reset ``_interpolate_empty_sequences`` to False after execution.

        This only applies to compiled statements, where pre_exec turned it on.
        """
        if not (self.compiled and self.identifier_preparer._double_percents):
            return
        _set_connection_interpolate_empty_sequences(self._dbapi_connection, False)

    @property
    def rowcount(self):
        """Row count reported by the DBAPI cursor."""
        return self.cursor.rowcount
832
+
833
+
class SnowflakeDDLCompiler(compiler.DDLCompiler):
    """DDL compiler for Snowflake.

    Covers:

    * column specifications and ``CLUSTER BY``;
    * stages and file formats;
    * comment removal;
    * identity columns and sequence/identity options.
    """

    def denormalize_column_name(self, name):
        """Return ``name`` as it should appear in DDL.

        * All-lower-case names that need no quoting are returned unchanged.
        * Any other name goes through ``preparer.quote``.
        * ``None`` is passed through.
        """
        if name is None:
            return None
        elif name.lower() == name and not self.preparer._requires_quotes(name.lower()):
            # no quote as case insensitive
            return name
        return self.preparer.quote(name)

    def get_column_specification(self, column, **kwargs):
        """Render one column of CREATE TABLE.

        Renders ``<name> <type> [NOT NULL] [DEFAULT ...]`` followed by
        ``[DEFAULT <seq>.nextval | AUTOINCREMENT] [IDENTITY...]``.
        """
        colspec = [
            self.preparer.format_column(column),
            self.dialect.type_compiler.process(column.type, type_expression=column),
        ]

        has_identity = (
            column.identity is not None and self.dialect.supports_identity_columns
        )

        if not column.nullable:
            colspec.append("NOT NULL")

        default = self.get_column_default_string(column)
        if default is not None:
            colspec.append("DEFAULT " + default)

        # TODO: This makes the first INTEGER column AUTOINCREMENT.
        # But the column is not really considered so unless
        # postfetch_lastrowid is enabled. But it is very unlikely to happen...
        if (
            column.table is not None
            and column is column.table._autoincrement_column
            and column.server_default is None
        ):
            if isinstance(column.default, Sequence):
                colspec.append(
                    f"DEFAULT {self.dialect.identifier_preparer.format_sequence(column.default)}.nextval"
                )
            else:
                colspec.append("AUTOINCREMENT")

        if has_identity:
            colspec.append(self.process(column.identity))

        return " ".join(colspec)

    def post_create_table(self, table):
        """
        Handles snowflake-specific ``CREATE TABLE ... CLUSTER BY`` syntax.

        Users can specify the `clusterby` property per table
        using the dialect specific syntax.
        For example, to specify a cluster by key you apply the following:

        >>> import sqlalchemy as sa
        >>> from sqlalchemy.schema import CreateTable
        >>> engine = sa.create_engine('snowflake://om1')
        >>> metadata = sa.MetaData()
        >>> user = sa.Table(
        ...     'user',
        ...     metadata,
        ...     sa.Column('id', sa.Integer, primary_key=True),
        ...     sa.Column('name', sa.String),
        ...     snowflake_clusterby=['id', 'name']
        ... )
        >>> print(CreateTable(user).compile(engine))
        <BLANKLINE>
        CREATE TABLE "user" (
            id INTEGER NOT NULL AUTOINCREMENT,
            name VARCHAR,
            PRIMARY KEY (id)
        ) CLUSTER BY (id, name)
        <BLANKLINE>
        <BLANKLINE>
        """
        text = ""
        info = table.dialect_options["snowflake"]
        cluster = info.get("clusterby")
        if cluster:
            text += " CLUSTER BY ({})".format(
                ", ".join(self.denormalize_column_name(key) for key in cluster)
            )
        return text

    def visit_create_stage(self, create_stage, **kw):
        """
        This visitor will create the SQL representation for a CREATE STAGE command.

        The URL is rendered with ``repr()`` of the stage's container.
        """
        return "CREATE {or_replace}{temporary}STAGE {}{} URL={}".format(
            create_stage.stage.namespace,
            create_stage.stage.name,
            repr(create_stage.container),
            or_replace="OR REPLACE " if create_stage.replace_if_exists else "",
            temporary="TEMPORARY " if create_stage.temporary else "",
        )

    def visit_create_file_format(self, file_format, **kw):
        """
        This visitor will create the SQL representation for a CREATE FILE FORMAT
        command.

        Option values are rendered with ``formatter.value_repr``.
        """
        return "CREATE {}FILE FORMAT {} TYPE='{}' {}".format(
            "OR REPLACE " if file_format.replace_if_exists else "",
            file_format.format_name,
            file_format.formatter.file_format,
            " ".join(
                [
                    f"{name} = {file_format.formatter.value_repr(name, value)}"
                    for name, value in file_format.formatter.options.items()
                ]
            ),
        )

    def visit_drop_table_comment(self, drop, **kw):
        """Snowflake does not support setting table comments as NULL.

        Reflection has to account for this and convert any empty comments to NULL.
        """
        table_name = self.preparer.format_table(drop.element)
        return f"COMMENT ON TABLE {table_name} IS ''"

    def visit_drop_column_comment(self, drop, **kw):
        """Snowflake does not support directly setting column comments as NULL.

        Instead we are forced to use the ALTER COLUMN ... UNSET COMMENT instead.
        """
        return "ALTER TABLE {} ALTER COLUMN {} UNSET COMMENT".format(
            self.preparer.format_table(drop.element.table),
            self.preparer.format_column(drop.element),
        )

    def visit_identity_column(self, identity, **kw):
        """Render `` IDENTITY`` or `` IDENTITY(start,increment)``.

        When either ``start`` or ``increment`` is set, the missing one
        defaults to 1.
        """
        text = " IDENTITY"
        if identity.start is not None or identity.increment is not None:
            start = 1 if identity.start is None else identity.start
            increment = 1 if identity.increment is None else identity.increment
            text += f"({start},{increment})"
        return text

    def get_identity_options(self, identity_options):
        """Render sequence/identity options as a space-separated string.

        Integer options are formatted with ``:d``. ``order`` renders as
        Snowflake's ``ORDER`` / ``NOORDER``.
        """
        text = []
        if identity_options.increment is not None:
            text.append(f"INCREMENT BY {identity_options.increment:d}")
        if identity_options.start is not None:
            text.append(f"START WITH {identity_options.start:d}")
        if identity_options.minvalue is not None:
            text.append(f"MINVALUE {identity_options.minvalue:d}")
        if identity_options.maxvalue is not None:
            text.append(f"MAXVALUE {identity_options.maxvalue:d}")
        # nominvalue / nomaxvalue are flags: an explicit False must not render
        # the clause (an ``is not None`` check would render it anyway).
        if identity_options.nominvalue:
            text.append("NO MINVALUE")
        if identity_options.nomaxvalue:
            text.append("NO MAXVALUE")
        if identity_options.cache is not None:
            text.append(f"CACHE {identity_options.cache:d}")
        if identity_options.cycle is not None:
            text.append("CYCLE" if identity_options.cycle else "NO CYCLE")
        if identity_options.order is not None:
            text.append("ORDER" if identity_options.order else "NOORDER")

        return " ".join(text)
998
+
999
+
class SnowflakeTypeCompiler(compiler.GenericTypeCompiler):
    """Type compiler that renders Snowflake type names.

    Every visitor here ignores the type's parameters and renders a fixed
    keyword. The ``visit_<NAME>`` methods are therefore produced by a small
    class-body factory from ``name -> keyword`` pairs.
    """

    def _keyword_visitor(keyword):
        """Build a ``visit_*`` method that always renders ``keyword``."""

        def visit(self, type_, **kw):
            return keyword

        return visit

    visit_BYTEINT = _keyword_visitor("BYTEINT")
    visit_CHARACTER = _keyword_visitor("CHARACTER")
    visit_DEC = _keyword_visitor("DEC")
    visit_DOUBLE = _keyword_visitor("DOUBLE")
    visit_FIXED = _keyword_visitor("FIXED")
    visit_INT = _keyword_visitor("INT")
    visit_NUMBER = _keyword_visitor("NUMBER")
    visit_STRING = _keyword_visitor("STRING")
    visit_TINYINT = _keyword_visitor("TINYINT")

    # semi-structured types
    visit_VARIANT = _keyword_visitor("VARIANT")
    visit_ARRAY = _keyword_visitor("ARRAY")
    visit_OBJECT = _keyword_visitor("OBJECT")

    # BLOB maps onto Snowflake's BINARY
    visit_BLOB = _keyword_visitor("BINARY")

    # date/time types; the generic ``datetime`` visitor renders lower case
    visit_datetime = _keyword_visitor("datetime")
    visit_DATETIME = _keyword_visitor("DATETIME")
    visit_TIMESTAMP_NTZ = _keyword_visitor("TIMESTAMP_NTZ")
    visit_TIMESTAMP_TZ = _keyword_visitor("TIMESTAMP_TZ")
    visit_TIMESTAMP_LTZ = _keyword_visitor("TIMESTAMP_LTZ")
    visit_TIMESTAMP = _keyword_visitor("TIMESTAMP")

    # geospatial types
    visit_GEOGRAPHY = _keyword_visitor("GEOGRAPHY")
    visit_GEOMETRY = _keyword_visitor("GEOMETRY")

    # the factory is a class-body helper only; keep it off the class
    del _keyword_visitor
1063
+
1064
+
# Dialect-specific keyword arguments accepted by ``Table``.
# ``snowflake_clusterby`` (default None) is read back via
# ``table.dialect_options["snowflake"]`` to render ``CLUSTER BY``.
construct_arguments = [
    (Table, {"clusterby": None}),
]