snowflake-sqlalchemy 1.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,621 @@
1
+ #
2
+ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+
5
+ from collections.abc import Sequence
6
+ from typing import List
7
+
8
+ from sqlalchemy import false, true
9
+ from sqlalchemy.sql.ddl import DDLElement
10
+ from sqlalchemy.sql.dml import UpdateBase
11
+ from sqlalchemy.sql.elements import ClauseElement
12
+ from sqlalchemy.sql.roles import FromClauseRole
13
+ from sqlalchemy.util.compat import string_types
14
+
15
# The type of ``None``; lets ``isinstance()`` checks in this module accept
# ``None`` alongside other permitted types.
NoneType = type(None)
16
+
17
+
18
def translate_bool(bln):
    """Map a Python truth value onto the matching SQLAlchemy boolean constant.

    Returns the result of ``true()`` when ``bln`` is truthy and the result of
    ``false()`` otherwise.
    """
    return true() if bln else false()
22
+
23
+
24
class MergeInto(UpdateBase):
    """Description of a ``MERGE INTO <target> USING <source> ON <on>`` statement.

    WHEN clauses are added through the ``when_*`` methods; each returns the
    newly created :class:`MergeInto.clause`, which can be refined further with
    ``values()`` and ``where()``.

    NOTE(review): SQL compilation is presumably done by a dialect visitor keyed
    on ``__visit_name__``; that visitor is not part of this module.
    """

    __visit_name__ = "merge_into"
    _bind = None

    def __init__(self, target, source, on):
        """Store the merge target, the source and the join condition."""
        self.target = target
        self.source = source
        self.on = on
        # WHEN clauses in the order they were added.
        self.clauses = []

    class clause(ClauseElement):
        """A single ``WHEN [NOT] MATCHED ... THEN <command>`` clause."""

        __visit_name__ = "merge_into_clause"

        def __init__(self, command):
            """Create a clause for ``command`` ("UPDATE", "DELETE" or "INSERT")."""
            # Column name -> value mapping; replaced wholesale by values().
            self.set = {}
            # Optional extra condition, rendered as " AND <predicate>".
            self.predicate = None
            self.command = command

        def __repr__(self):
            """Return a textual rendering of the clause.

            An INSERT clause without any values renders with empty column and
            value lists instead of raising, so ``repr()`` is always safe to
            call (e.g. from a debugger or a log statement).
            """
            predicate_sql = (
                f" AND {self.predicate}" if self.predicate is not None else ""
            )
            if self.command == "INSERT":
                # Joining keys and values separately (rather than unpacking
                # zip(*items)) keeps this working when self.set is empty.
                columns = ", ".join(self.set)
                values = ", ".join(map(str, self.set.values()))
                return (
                    f"WHEN NOT MATCHED{predicate_sql} THEN {self.command}"
                    f" ({columns}) VALUES ({values})"
                )
            # WHEN MATCHED clause (UPDATE / DELETE); SET is only emitted when
            # there is something to assign.
            assignments = ", ".join(
                f"{column} = {value}" for column, value in self.set.items()
            )
            set_sql = f" SET {assignments}" if self.set else ""
            return f"WHEN MATCHED{predicate_sql} THEN {self.command}{set_sql}"

        def values(self, **kwargs):
            """Replace the column/value mapping of this clause; returns ``self``."""
            self.set = kwargs
            return self

        def where(self, expr):
            """Attach an additional predicate to this clause; returns ``self``."""
            self.predicate = expr
            return self

    def __repr__(self):
        """Return the MERGE statement followed by the repr of every clause."""
        clauses = " ".join(repr(item) for item in self.clauses)
        statement = f"MERGE INTO {self.target} USING {self.source} ON {self.on}"
        return f"{statement} {clauses}" if clauses else statement

    def when_matched_then_update(self):
        """Append and return a ``WHEN MATCHED THEN UPDATE`` clause."""
        clause = self.clause("UPDATE")
        self.clauses.append(clause)
        return clause

    def when_matched_then_delete(self):
        """Append and return a ``WHEN MATCHED THEN DELETE`` clause."""
        clause = self.clause("DELETE")
        self.clauses.append(clause)
        return clause

    def when_not_matched_then_insert(self):
        """Append and return a ``WHEN NOT MATCHED THEN INSERT`` clause."""
        clause = self.clause("INSERT")
        self.clauses.append(clause)
        return clause
95
+
96
+
97
class FilesOption:
    """
    Class to represent FILES option for the snowflake COPY INTO statement
    """

    def __init__(self, file_names: List[str]):
        """Remember the file names to be listed in the FILES option."""
        self.file_names = file_names

    def __str__(self):
        """Render the names as ``('a','b',...)`` with each name single-quoted.

        Backslashes are doubled before single quotes are escaped: the quote
        escape itself relies on backslash being the escape character, so a raw
        backslash in a name would otherwise be read as the start of an escape
        sequence (and a trailing one would swallow the closing quote).
        """
        the_files = [
            "'" + name.replace("\\", "\\\\").replace("'", "\\'") + "'"
            for name in self.file_names
        ]
        return f"({','.join(the_files)})"
108
+
109
+
110
class CopyInto(UpdateBase):
    """Base class describing a COPY INTO command.

    Reference documentation:
    https://docs.snowflake.net/manuals/sql-reference/sql/copy-into-location.html

    Copy options are collected in ``self.copy_options``; every option setter
    returns ``self`` so that calls can be chained.
    """

    __visit_name__ = "copy_into"
    _bind = None

    def __init__(self, from_, into, formatter=None):
        """Store the copy source, the copy destination and an optional formatter."""
        self.from_ = from_
        self.into = into
        self.formatter = formatter
        # Option name (upper case) -> option value.
        self.copy_options = {}

    def __repr__(self):
        """
        Debugging / logging representation only; the actual compilation logic
        is in the corresponding visitor in base.py
        """
        return (
            f"COPY INTO {self.into} FROM {self.from_!r} "
            f"{self.formatter!r} ({self.copy_options})"
        )

    def bind(self):
        """Always returns ``None``."""
        return None

    def force(self, force):
        """Set the FORCE copy option; ``force`` must be a bool."""
        if isinstance(force, bool):
            self.copy_options["FORCE"] = translate_bool(force)
            return self
        raise TypeError("Parameter force should be a boolean value")

    def single(self, single_file):
        """Set the SINGLE copy option; ``single_file`` must be a bool."""
        if isinstance(single_file, bool):
            self.copy_options["SINGLE"] = translate_bool(single_file)
            return self
        raise TypeError("Parameter single_file should be a boolean value")

    def maxfilesize(self, max_size):
        """Set the MAX_FILE_SIZE copy option; ``max_size`` must be an int."""
        if isinstance(max_size, int):
            self.copy_options["MAX_FILE_SIZE"] = max_size
            return self
        raise TypeError("Parameter max_size should be an integer value")

    def files(self, file_names):
        """Set the FILES copy option, wrapping ``file_names`` in a FilesOption."""
        self.copy_options["FILES"] = FilesOption(file_names)
        return self

    def pattern(self, pattern):
        """Set the PATTERN copy option to ``pattern`` (stored without validation)."""
        self.copy_options["PATTERN"] = pattern
        return self
158
+
159
+
160
class CopyFormatter(ClauseElement):
    """
    Common base for the format specifications used inside a COPY INTO
    statement. It may also be used to create a named format.
    """

    __visit_name__ = "copy_formatter"

    def __init__(self, format_name=None):
        """Start with no options; a truthy ``format_name`` is stored under "format_name"."""
        self.options = {}
        if format_name:
            self.options["format_name"] = format_name

    def __repr__(self):
        """
        Debugging / logging representation only; the actual compilation logic
        is in the corresponding visitor in base.py
        """
        return f"FILE_FORMAT=({self.options})"

    @staticmethod
    def value_repr(name, value):
        """
        Build the SQL text for an option value. Called from the corresponding
        visitor function (base.py/visit_copy_formatter()).

        - a format name (``name == "format_name"``) is returned untouched,
          without quotes
        - a string is wrapped in single quotes: 'value'
        - a tuple with exactly one element becomes ('value'); Python's own
          stringification would give (value,) with a trailing comma, which is
          not correct in SQL
        - anything else is converted with str()
        """
        if name == "format_name":
            return value
        if isinstance(value, str):
            return "'" + value + "'"
        is_singleton_tuple = isinstance(value, tuple) and len(value) == 1
        if is_singleton_tuple:
            (only,) = value
            return f"('{only}')"
        return str(value)
200
+
201
+
202
class CSVFormatter(CopyFormatter):
    """CSV-specific file format options.

    Each setter validates its argument, stores it in ``self.options`` under
    the upper-case option name and returns ``self`` so calls can be chained.
    """

    file_format = "csv"

    def compression(self, comp_type):
        """String (constant) that specifies the compression algorithm used for
        the unloaded data files.

        String arguments are lower-cased before validation; ``None`` is
        accepted as well.

        Raises:
            TypeError: if ``comp_type`` is not one of the supported values.
        """
        if isinstance(comp_type, string_types):
            comp_type = comp_type.lower()
        _available_options = [
            "auto",
            "gzip",
            "bz2",
            "brotli",
            "zstd",
            "deflate",
            "raw_deflate",
            None,
        ]
        if comp_type not in _available_options:
            raise TypeError(f"Compression type should be one of : {_available_options}")
        self.options["COMPRESSION"] = comp_type
        return self

    def _check_delimiter(self, delimiter, delimiter_txt):
        """
        Check if a delimiter is either a string of length 1 or an integer. In case of
        a string delimiter, take into account that the actual string may be longer,
        but still evaluate to a single character (like "\\n" or r"\\n").

        ``None`` is accepted too. ``delimiter_txt`` is the human-readable option
        name used at the start of the error message.

        Raises:
            TypeError: if ``delimiter`` is none of the accepted forms.
        """
        if isinstance(delimiter, NoneType):
            return
        if isinstance(delimiter, string_types):
            # Resolve backslash escape sequences before measuring the length.
            delimiter_processed = delimiter.encode().decode("unicode_escape")
            if len(delimiter_processed) == 1:
                return
        if isinstance(delimiter, int):
            return
        raise TypeError(
            f"{delimiter_txt} should be a single character, that is either a string, or a number"
        )

    def _set_delimiter_option(self, option_name, value, description):
        """Validate a delimiter-like ``value`` and store it under ``option_name``.

        Integers are stored as their ``hex()`` string; everything else is
        stored unchanged.
        """
        self._check_delimiter(value, description)
        self.options[option_name] = hex(value) if isinstance(value, int) else value
        return self

    def record_delimiter(self, deli_type):
        """Character that separates records in an unloaded file."""
        return self._set_delimiter_option(
            "RECORD_DELIMITER", deli_type, "Record delimiter"
        )

    def field_delimiter(self, deli_type):
        """Character that separates fields in an unloaded file."""
        return self._set_delimiter_option(
            "FIELD_DELIMITER", deli_type, "Field delimiter"
        )

    def file_extension(self, ext):
        """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is
        responsible for specifying a valid file extension that can be read by the desired software or service.

        Raises:
            TypeError: if ``ext`` is neither a string nor ``None``.
        """
        if not isinstance(ext, (NoneType, string_types)):
            raise TypeError("File extension should be a string")
        self.options["FILE_EXTENSION"] = ext
        return self

    def date_format(self, dt_frmt):
        """String that defines the format of date values in the unloaded data files.

        Raises:
            TypeError: if ``dt_frmt`` is not a string.
        """
        if not isinstance(dt_frmt, string_types):
            raise TypeError("Date format should be a string")
        self.options["DATE_FORMAT"] = dt_frmt
        return self

    def time_format(self, tm_frmt):
        """String that defines the format of time values in the unloaded data files.

        Raises:
            TypeError: if ``tm_frmt`` is not a string.
        """
        if not isinstance(tm_frmt, string_types):
            raise TypeError("Time format should be a string")
        self.options["TIME_FORMAT"] = tm_frmt
        return self

    def timestamp_format(self, tmstmp_frmt):
        """String that defines the format of timestamp values in the unloaded data files.

        Raises:
            TypeError: if ``tmstmp_frmt`` is not a string.
        """
        if not isinstance(tmstmp_frmt, string_types):
            raise TypeError("Timestamp format should be a string")
        self.options["TIMESTAMP_FORMAT"] = tmstmp_frmt
        return self

    def binary_format(self, bin_fmt):
        """Encoding format for binary values: one of "hex", "base64" or "utf8"
        (matched case-insensitively). The option can be used when unloading data
        from binary columns in a table.

        Raises:
            TypeError: if ``bin_fmt`` is not one of the supported values.
        """
        if isinstance(bin_fmt, string_types):
            bin_fmt = bin_fmt.lower()
        _available_options = ["hex", "base64", "utf8"]
        if bin_fmt not in _available_options:
            raise TypeError(f"Binary format should be one of : {_available_options}")
        self.options["BINARY_FORMAT"] = bin_fmt
        return self

    def escape(self, esc):
        """Character used as the escape character for any field values."""
        return self._set_delimiter_option("ESCAPE", esc, "Escape")

    def escape_unenclosed_field(self, esc):
        """Single character string used as the escape character for unenclosed field values only."""
        return self._set_delimiter_option(
            "ESCAPE_UNENCLOSED_FIELD", esc, "Escape unenclosed field"
        )

    def field_optionally_enclosed_by(self, enc):
        """Character used to enclose strings. Either None, ', or \".

        Raises:
            TypeError: if ``enc`` is not one of the three accepted values.
        """
        _available_options = [None, "'", '"']
        if enc not in _available_options:
            raise TypeError(f"Enclosing string should be one of : {_available_options}")
        self.options["FIELD_OPTIONALLY_ENCLOSED_BY"] = enc
        return self

    def null_if(self, null_value):
        """Copying into a table these strings will be replaced by a NULL, while copying out of Snowflake will replace
        NULL values with the first string

        The values are stored as a tuple.

        Raises:
            TypeError: if ``null_value`` is not a ``Sequence``.
        """
        # NOTE(review): a plain str is itself a Sequence, so passing "NULL"
        # stores ('N', 'U', 'L', 'L'); callers should pass a list or tuple.
        if not isinstance(null_value, Sequence):
            raise TypeError("Parameter null_value should be an iterable")
        self.options["NULL_IF"] = tuple(null_value)
        return self

    def skip_header(self, skip_header):
        """
        Number of header rows to be skipped at the beginning of the file

        Raises:
            TypeError: if ``skip_header`` is not an int.
        """
        if not isinstance(skip_header, int):
            raise TypeError("skip_header should be an int")
        self.options["SKIP_HEADER"] = skip_header
        return self

    def trim_space(self, trim_space):
        """
        Remove leading or trailing white spaces

        Raises:
            TypeError: if ``trim_space`` is not a bool.
        """
        if not isinstance(trim_space, bool):
            raise TypeError("trim_space should be a bool")
        self.options["TRIM_SPACE"] = trim_space
        return self

    def error_on_column_count_mismatch(self, error_on_col_count_mismatch):
        """
        Generate a parsing error if the number of delimited columns (i.e. fields) in
        an input data file does not match the number of columns in the corresponding table.

        Raises:
            TypeError: if ``error_on_col_count_mismatch`` is not a bool.
        """
        if not isinstance(error_on_col_count_mismatch, bool):
            # The message names this option (it used to say "skip_header").
            raise TypeError("error_on_column_count_mismatch should be a bool")
        self.options["ERROR_ON_COLUMN_COUNT_MISMATCH"] = error_on_col_count_mismatch
        return self
362
+
363
+
364
class JSONFormatter(CopyFormatter):
    """JSON-specific file format options; setters return ``self`` for chaining."""

    file_format = "json"

    def compression(self, comp_type):
        """Select the compression algorithm applied to the unloaded data files.

        A string argument is lower-cased first; ``None`` is also accepted.
        Raises TypeError for any value outside the supported list.
        """
        normalized = (
            comp_type.lower() if isinstance(comp_type, string_types) else comp_type
        )
        _available_options = [
            "auto",
            "gzip",
            "bz2",
            "brotli",
            "zstd",
            "deflate",
            "raw_deflate",
            None,
        ]
        if normalized in _available_options:
            self.options["COMPRESSION"] = normalized
            return self
        raise TypeError(f"Compression type should be one of : {_available_options}")

    def file_extension(self, ext):
        """Set the extension for files unloaded to a stage.

        Any extension is accepted; choosing one that the consuming software or
        service can read is up to the caller. Raises TypeError unless ``ext``
        is a string or ``None``.
        """
        if isinstance(ext, (NoneType, string_types)):
            self.options["FILE_EXTENSION"] = ext
            return self
        raise TypeError("File extension should be a string")
396
+
397
+
398
class PARQUETFormatter(CopyFormatter):
    """Parquet-specific file format options; setters return ``self`` for chaining."""

    file_format = "parquet"

    def snappy_compression(self, comp):
        """Turn snappy compression on or off; ``comp`` must be a bool."""
        if isinstance(comp, bool):
            self.options["SNAPPY_COMPRESSION"] = translate_bool(comp)
            return self
        raise TypeError("Comp should be a Boolean value")

    def compression(self, comp):
        """
        Choose the compression type; ``comp`` must be a str and is stored as given
        """
        if isinstance(comp, str):
            self.options["COMPRESSION"] = comp
            return self
        raise TypeError("Comp should be a str value")

    def binary_as_text(self, value):
        """Turn the binary-as-text option on or off; ``value`` must be a bool."""
        if isinstance(value, bool):
            self.options["BINARY_AS_TEXT"] = translate_bool(value)
            return self
        raise TypeError("binary_as_text should be a Boolean value")
425
+
426
+
427
class ExternalStage(ClauseElement, FromClauseRole):
    """Descriptor of an external stage: name plus optional namespace, path and file format."""

    __visit_name__ = "external_stage"
    _hide_froms = ()

    @staticmethod
    def prepare_namespace(namespace):
        """Return ``namespace`` guaranteed to end with a single trailing dot separator."""
        if namespace.endswith("."):
            return namespace
        return namespace + "."

    @staticmethod
    def prepare_path(path):
        """Return ``path`` guaranteed to start with a slash."""
        if path.startswith("/"):
            return path
        return "/" + path

    def __init__(self, name, path=None, namespace=None, file_format=None):
        """Store the stage parts; a missing path or namespace becomes an empty string."""
        self.name = name
        self.path = ""
        if path:
            self.path = self.prepare_path(path)
        self.namespace = ""
        if namespace:
            self.namespace = self.prepare_namespace(namespace)
        self.file_format = file_format

    def __repr__(self):
        """Render as ``@<namespace><name><path> (<file_format>)``."""
        location = f"@{self.namespace}{self.name}{self.path}"
        return f"{location} ({self.file_format})"

    @classmethod
    def from_parent_stage(cls, parent_stage, path, file_format=None):
        """
        Build a new stage from an existing parent stage (with or without a
        path of its own) by appending an additional sub-path
        """
        combined_path = f"{parent_stage.path}/{path}"
        return cls(
            parent_stage.name,
            combined_path,
            parent_stage.namespace,
            file_format,
        )
462
+
463
+
464
class CreateFileFormat(DDLElement):
    """
    Represents a CREATE FILE FORMAT statement, built from a format name and a
    format description (the same kind used in a COPY INTO statement).
    """

    __visit_name__ = "create_file_format"

    def __init__(self, format_name, formatter, replace_if_exists=False):
        """Keep the format name, its formatter and the replace-if-exists flag."""
        super().__init__()
        self.replace_if_exists = replace_if_exists
        self.formatter = formatter
        self.format_name = format_name
477
+
478
+
479
class CreateStage(DDLElement):
    """
    Represents a CREATE STAGE statement, built from a container (the physical
    base for the stage) and the actual ExternalStage object.
    """

    __visit_name__ = "create_stage"

    def __init__(self, container, stage, replace_if_exists=False, *, temporary=False):
        """Keep the container, the stage and the two flags.

        ``temporary`` is keyword-only.
        """
        super().__init__()
        self.stage = stage
        self.container = container
        self.replace_if_exists = replace_if_exists
        self.temporary = temporary
493
+
494
+
495
class AWSBucket(ClauseElement):
    """AWS S3 bucket descriptor

    Holds the bucket name, an optional path inside the bucket, and the
    credential / encryption settings chosen through the fluent setters.
    """

    __visit_name__ = "aws_bucket"

    def __init__(self, bucket, path=None):
        """Store the bucket name and optional path; no credentials or encryption yet."""
        self.bucket = bucket
        self.path = path
        self.encryption_used = {}
        self.credentials_used = {}

    @classmethod
    def from_uri(cls, uri):
        """Build a bucket descriptor from an ``s3://bucket[/path]`` URI.

        Raises:
            ValueError: if ``uri`` does not start with ``s3://``.
        """
        if not uri.startswith("s3://"):
            raise ValueError(f"Invalid AWS bucket URI: {uri}")
        # Everything up to the first "/" is the bucket; the rest (possibly an
        # empty string) is the path. Without any "/" there is no path at all.
        bucket, separator, path = uri[5:].partition("/")
        return cls(bucket, path if separator else None)

    def __repr__(self):
        """Render the quoted URI followed by any CREDENTIALS / ENCRYPTION settings."""
        credentials = "CREDENTIALS=({})".format(
            " ".join(f"{n}='{v}'" for n, v in self.credentials_used.items())
        )
        # String values are quoted, anything else is rendered bare.
        encryption = "ENCRYPTION=({})".format(
            " ".join(
                f"{n}='{v}'" if isinstance(v, string_types) else f"{n}={v}"
                for n, v in self.encryption_used.items()
            )
        )
        uri = "'s3://{}{}'".format(self.bucket, f"/{self.path}" if self.path else "")
        return "{}{}{}".format(
            uri,
            f" {credentials}" if self.credentials_used else "",
            f" {encryption}" if self.encryption_used else "",
        )

    def credentials(
        self, aws_role=None, aws_key_id=None, aws_secret_key=None, aws_token=None
    ):
        """Select the credentials for this bucket; returns ``self``.

        Either ``aws_role`` or *both* ``aws_key_id`` and ``aws_secret_key``
        must be supplied. ``aws_token`` is optional and only used together
        with the key pair.

        Raises:
            ValueError: if no role is given and the key pair is incomplete.
        """
        # Require the complete key pair: accepting just one of the two would
        # store None for the other and render it as the literal 'None'.
        if not aws_role and (aws_key_id is None or aws_secret_key is None):
            raise ValueError(
                "Either 'aws_role', or aws_key_id and aws_secret_key has to be supplied"
            )
        if aws_role:
            self.credentials_used = {"AWS_ROLE": aws_role}
        else:
            self.credentials_used = {
                "AWS_SECRET_KEY": aws_secret_key,
                "AWS_KEY_ID": aws_key_id,
            }
            if aws_token:
                self.credentials_used["AWS_TOKEN"] = aws_token
        return self

    def encryption_aws_cse(self, master_key):
        """Use AWS_CSE encryption with the given master key; returns ``self``."""
        self.encryption_used = {"TYPE": "AWS_CSE", "MASTER_KEY": master_key}
        return self

    def encryption_aws_sse_s3(self):
        """Use AWS_SSE_S3 encryption; returns ``self``."""
        self.encryption_used = {"TYPE": "AWS_SSE_S3"}
        return self

    def encryption_aws_sse_kms(self, kms_key_id=None):
        """Use AWS_SSE_KMS encryption, optionally with a KMS key id; returns ``self``."""
        self.encryption_used = {"TYPE": "AWS_SSE_KMS"}
        if kms_key_id:
            self.encryption_used["KMS_KEY_ID"] = kms_key_id
        return self
565
+
566
+
567
class AzureContainer(ClauseElement):
    """Microsoft Azure Container descriptor

    Holds the storage account, the container name, an optional path inside
    the container, and the credential / encryption settings chosen through
    the fluent setters.
    """

    __visit_name__ = "azure_container"

    def __init__(self, account, container, path=None):
        """Store account, container and optional path; no credentials or encryption yet."""
        self.account = account
        self.container = container
        self.path = path
        self.encryption_used = {}
        self.credentials_used = {}

    @classmethod
    def from_uri(cls, uri):
        """Build a descriptor from
        ``azure://<account>.blob.core.windows.net/<container>[/path]``.

        Raises:
            ValueError: if ``uri`` does not have that shape. The message always
                quotes the complete URI that was passed in.
        """
        scheme = "azure://"
        host_suffix = "blob.core.windows.net/"
        if not uri.startswith(scheme):
            raise ValueError(f"Invalid Azure Container URI: {uri}")
        # The account is everything up to the first "."; a URI without any "."
        # is rejected with the same explicit error as any other malformed URI.
        account, dot, remainder = uri[len(scheme):].partition(".")
        if not dot or not remainder.startswith(host_suffix):
            raise ValueError(f"Invalid Azure Container URI: {uri}")
        # Up to the first "/" is the container; the rest (possibly an empty
        # string) is the path. Without any "/" there is no path at all.
        container, separator, path = remainder[len(host_suffix):].partition("/")
        return cls(account, container, path if separator else None)

    def __repr__(self):
        """Render the quoted URI followed by any CREDENTIALS / ENCRYPTION settings."""
        credentials = "CREDENTIALS=({})".format(
            " ".join(f"{n}='{v}'" for n, v in self.credentials_used.items())
        )
        # String values are quoted, anything else is rendered bare.
        encryption = "ENCRYPTION=({})".format(
            " ".join(
                f"{n}='{v}'" if isinstance(v, string_types) else f"{n}={v}"
                for n, v in self.encryption_used.items()
            )
        )
        uri = "'azure://{}.blob.core.windows.net/{}{}'".format(
            self.account, self.container, f"/{self.path}" if self.path else ""
        )
        return "{}{}{}".format(
            uri,
            f" {credentials}" if self.credentials_used else "",
            f" {encryption}" if self.encryption_used else "",
        )

    def credentials(self, azure_sas_token):
        """Use the given SAS token as credentials; returns ``self``."""
        self.credentials_used = {"AZURE_SAS_TOKEN": azure_sas_token}
        return self

    def encryption_azure_cse(self, master_key):
        """Use AZURE_CSE encryption with the given master key; returns ``self``."""
        self.encryption_used = {"TYPE": "AZURE_CSE", "MASTER_KEY": master_key}
        return self
619
+
620
+
621
# Second name bound to the very same class object as CopyInto.
# NOTE(review): presumably kept so existing imports of the older name keep
# working — confirm before removing.
CopyIntoStorage = CopyInto