snowflake-sqlalchemy 1.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/sqlalchemy/__init__.py +116 -0
- snowflake/sqlalchemy/_constants.py +12 -0
- snowflake/sqlalchemy/base.py +1065 -0
- snowflake/sqlalchemy/custom_commands.py +621 -0
- snowflake/sqlalchemy/custom_types.py +105 -0
- snowflake/sqlalchemy/provision.py +12 -0
- snowflake/sqlalchemy/requirements.py +297 -0
- snowflake/sqlalchemy/snowdialect.py +911 -0
- snowflake/sqlalchemy/util.py +336 -0
- snowflake/sqlalchemy/version.py +6 -0
- snowflake_sqlalchemy-1.5.2.dist-info/METADATA +503 -0
- snowflake_sqlalchemy-1.5.2.dist-info/RECORD +15 -0
- snowflake_sqlalchemy-1.5.2.dist-info/WHEEL +4 -0
- snowflake_sqlalchemy-1.5.2.dist-info/entry_points.txt +2 -0
- snowflake_sqlalchemy-1.5.2.dist-info/licenses/LICENSE.txt +202 -0
|
@@ -0,0 +1,621 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
from typing import List
|
|
7
|
+
|
|
8
|
+
from sqlalchemy import false, true
|
|
9
|
+
from sqlalchemy.sql.ddl import DDLElement
|
|
10
|
+
from sqlalchemy.sql.dml import UpdateBase
|
|
11
|
+
from sqlalchemy.sql.elements import ClauseElement
|
|
12
|
+
from sqlalchemy.sql.roles import FromClauseRole
|
|
13
|
+
from sqlalchemy.util.compat import string_types
|
|
14
|
+
|
|
15
|
+
# The class of the ``None`` singleton, for ``isinstance`` checks that must accept None.
NoneType = None.__class__
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def translate_bool(bln):
    """Map a Python value onto SQLAlchemy's boolean SQL constants.

    :param bln: any value; only its truthiness is inspected.
    :returns: ``true()`` when ``bln`` is truthy, otherwise ``false()``.
    """
    return true() if bln else false()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class MergeInto(UpdateBase):
    """``MERGE INTO <target> USING <source> ON <on>`` statement builder.

    WHEN branches are appended with the ``when_*`` methods. Each returns the
    newly created clause so it can be refined with ``values(...)`` and/or
    ``where(...)``.
    """

    __visit_name__ = "merge_into"
    _bind = None

    def __init__(self, target, source, on):
        self.target = target
        self.source = source
        self.on = on
        # WHEN clauses, in the order they were added.
        self.clauses = []

    class clause(ClauseElement):
        """A single ``WHEN [NOT] MATCHED ... THEN <command>`` branch."""

        __visit_name__ = "merge_into_clause"

        def __init__(self, command):
            # Column -> value mapping supplied through values().
            self.set = {}
            # Optional extra condition, rendered as " AND <predicate>".
            self.predicate = None
            # "UPDATE", "DELETE" or "INSERT" (see the when_* builders).
            self.command = command

        def __repr__(self):
            case_predicate = (
                f" AND {str(self.predicate)}" if self.predicate is not None else ""
            )
            if self.command == "INSERT":
                # zip(*{}.items()) yields nothing, so unpacking it raised
                # ValueError for an INSERT whose values() were never given.
                if self.set:
                    sets, sets_tos = zip(*self.set.items())
                else:
                    sets, sets_tos = (), ()
                return "WHEN NOT MATCHED{} THEN {} ({}) VALUES ({})".format(
                    case_predicate,
                    self.command,
                    ", ".join(sets),
                    ", ".join(map(str, sets_tos)),
                )
            # WHEN MATCHED clause (UPDATE / DELETE)
            assignments = ", ".join(
                f"{column} = {value}" for column, value in self.set.items()
            )
            return "WHEN MATCHED{} THEN {}{}".format(
                case_predicate,
                self.command,
                f" SET {assignments}" if self.set else "",
            )

        def values(self, **kwargs):
            """Replace the column -> value mapping with ``kwargs``; return self."""
            self.set = kwargs
            return self

        def where(self, expr):
            """Set the extra predicate ANDed to the WHEN condition; return self."""
            self.predicate = expr
            return self

    def __repr__(self):
        clauses = " ".join([repr(clause) for clause in self.clauses])
        return f"MERGE INTO {self.target} USING {self.source} ON {self.on}" + (
            f" {clauses}" if clauses else ""
        )

    def when_matched_then_update(self):
        """Append and return a ``WHEN MATCHED THEN UPDATE`` clause."""
        clause = self.clause("UPDATE")
        self.clauses.append(clause)
        return clause

    def when_matched_then_delete(self):
        """Append and return a ``WHEN MATCHED THEN DELETE`` clause."""
        clause = self.clause("DELETE")
        self.clauses.append(clause)
        return clause

    def when_not_matched_then_insert(self):
        """Append and return a ``WHEN NOT MATCHED THEN INSERT`` clause."""
        clause = self.clause("INSERT")
        self.clauses.append(clause)
        return clause
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class FilesOption:
    """
    Class to represent FILES option for the snowflake COPY INTO statement.

    ``str()`` renders the names as a comma-separated list of single-quoted
    string literals wrapped in parentheses, e.g. ``('a.csv','b.csv')``.
    """

    def __init__(self, file_names: List[str]):
        self.file_names = file_names

    @staticmethod
    def _quote(file_name: str) -> str:
        """Render one file name as a single-quoted, backslash-escaped literal."""
        # Backslash is the escape character in these literals (the quote is
        # escaped as \'), so backslashes must be doubled first. Otherwise a
        # name ending in "\" would escape its own closing quote.
        escaped = file_name.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def __str__(self):
        return f"({','.join(self._quote(f) for f in self.file_names)})"
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class CopyInto(UpdateBase):
    """Copy Into Command base class, for documentation see:
    https://docs.snowflake.net/manuals/sql-reference/sql/copy-into-location.html

    The option setters validate their argument, record it in
    ``copy_options`` and return ``self`` so calls can be chained.
    """

    __visit_name__ = "copy_into"
    _bind = None

    def __init__(self, from_, into, formatter=None):
        self.from_ = from_
        self.into = into
        self.formatter = formatter
        # Option name -> value, filled in by the setters below.
        self.copy_options = {}

    def __repr__(self):
        """
        Debug/logging representation only; SQL compilation is done by the
        corresponding visitor in base.py.
        """
        source_txt = repr(self.from_)
        format_txt = repr(self.formatter)
        return f"COPY INTO {self.into} FROM {source_txt} {format_txt} ({self.copy_options})"

    def bind(self):
        """Return ``None`` unconditionally."""
        return None

    def force(self, force):
        """Set the FORCE option; ``force`` must be a ``bool``."""
        if isinstance(force, bool):
            self.copy_options["FORCE"] = translate_bool(force)
            return self
        raise TypeError("Parameter force should be a boolean value")

    def single(self, single_file):
        """Set the SINGLE option; ``single_file`` must be a ``bool``."""
        if isinstance(single_file, bool):
            self.copy_options["SINGLE"] = translate_bool(single_file)
            return self
        raise TypeError("Parameter single_file should be a boolean value")

    def maxfilesize(self, max_size):
        """Set the MAX_FILE_SIZE option; ``max_size`` must be an ``int``."""
        if isinstance(max_size, int):
            self.copy_options["MAX_FILE_SIZE"] = max_size
            return self
        raise TypeError("Parameter max_size should be an integer value")

    def files(self, file_names):
        """Set the FILES option from a list of file names."""
        self.copy_options["FILES"] = FilesOption(file_names)
        return self

    def pattern(self, pattern):
        """Set the PATTERN option (stored unvalidated)."""
        self.copy_options["PATTERN"] = pattern
        return self
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class CopyFormatter(ClauseElement):
    """
    Base class for FILE_FORMAT specifications inside a COPY INTO statement.
    Can also describe a named format (see ``format_name``).
    """

    __visit_name__ = "copy_formatter"

    def __init__(self, format_name=None):
        # Option name -> value; subclasses' setters add entries here.
        self.options = {}
        if format_name:
            self.options["format_name"] = format_name

    def __repr__(self):
        """
        Debug/logging representation only; SQL compilation is done by the
        corresponding visitor in base.py.
        """
        return "FILE_FORMAT=({})".format(self.options)

    @staticmethod
    def value_repr(name, value):
        """
        Produce the SQL text for an option value. This is called from the
        corresponding visitor function (base.py/visit_copy_formatter()).

        - ``format_name``: returned unchanged (no quotes).
        - ``str``: wrapped in single quotes, with no escaping applied.
        - 1-element ``tuple``: rendered as ``('<element>')``, avoiding the
          trailing comma Python's own stringification would add.
        - anything else: ``str(value)``. For longer tuples this means Python's
          tuple formatting, which uses ``repr()`` for each element.
        """
        if name == "format_name":
            return value
        if isinstance(value, str):
            return f"'{value}'"
        if isinstance(value, tuple) and len(value) == 1:
            (only_item,) = value
            return f"('{only_item}')"
        return str(value)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class CSVFormatter(CopyFormatter):
    """CSV-specific FILE_FORMAT options.

    Every public setter validates its argument, stores it in ``self.options``
    and returns ``self`` so calls can be chained.
    """

    file_format = "csv"

    def compression(self, comp_type):
        """Compression algorithm of the data files (case-insensitive; ``None`` allowed)."""
        if isinstance(comp_type, string_types):
            comp_type = comp_type.lower()
        _available_options = [
            "auto",
            "gzip",
            "bz2",
            "brotli",
            "zstd",
            "deflate",
            "raw_deflate",
            None,
        ]
        if comp_type not in _available_options:
            raise TypeError(f"Compression type should be one of : {_available_options}")
        self.options["COMPRESSION"] = comp_type
        return self

    def _check_delimiter(self, delimiter, delimiter_txt):
        """
        Check that a delimiter is ``None``, an ``int``, or a string that denotes
        exactly one character once backslash escapes are interpreted. For
        example, both "\\n" and r"\n" are accepted.

        :raises TypeError: for anything else, including malformed escapes.
        """
        if delimiter is None:
            return
        if isinstance(delimiter, string_types):
            try:
                # Encode via latin-1 + backslashreplace rather than UTF-8. Then
                # a non-ASCII character (e.g. "é", "€") round-trips through
                # unicode_escape as ONE character instead of several latin-1 ones.
                delimiter_processed = delimiter.encode(
                    "latin-1", "backslashreplace"
                ).decode("unicode_escape")
            except UnicodeDecodeError:
                # e.g. a lone trailing backslash; report it as invalid below.
                delimiter_processed = None
            if delimiter_processed is not None and len(delimiter_processed) == 1:
                return
        if isinstance(delimiter, int):
            return
        raise TypeError(
            f"{delimiter_txt} should be a single character, that is either a string, or a number"
        )

    def _set_single_char_option(self, option, value, value_txt):
        """Validate ``value`` with ``_check_delimiter`` and store it under ``option``."""
        self._check_delimiter(value, value_txt)
        # Integers are stored via hex() (presumably a character code);
        # strings and None are stored unchanged.
        self.options[option] = hex(value) if isinstance(value, int) else value
        return self

    def record_delimiter(self, deli_type):
        """Character that separates records in an unloaded file."""
        return self._set_single_char_option(
            "RECORD_DELIMITER", deli_type, "Record delimiter"
        )

    def field_delimiter(self, deli_type):
        """Character that separates fields in an unloaded file."""
        return self._set_single_char_option(
            "FIELD_DELIMITER", deli_type, "Field delimiter"
        )

    def file_extension(self, ext):
        """Extension for files unloaded to a stage. Any extension is accepted; the
        user is responsible for choosing one the consuming software can read.
        """
        if not isinstance(ext, (NoneType, string_types)):
            raise TypeError("File extension should be a string")
        self.options["FILE_EXTENSION"] = ext
        return self

    def date_format(self, dt_frmt):
        """String that defines the format of date values in the data files."""
        if not isinstance(dt_frmt, string_types):
            raise TypeError("Date format should be a string")
        self.options["DATE_FORMAT"] = dt_frmt
        return self

    def time_format(self, tm_frmt):
        """String that defines the format of time values in the data files."""
        if not isinstance(tm_frmt, string_types):
            raise TypeError("Time format should be a string")
        self.options["TIME_FORMAT"] = tm_frmt
        return self

    def timestamp_format(self, tmstmp_frmt):
        """String that defines the format of timestamp values in the data files."""
        if not isinstance(tmstmp_frmt, string_types):
            raise TypeError("Timestamp format should be a string")
        self.options["TIMESTAMP_FORMAT"] = tmstmp_frmt
        return self

    def binary_format(self, bin_fmt):
        """Encoding of binary column values: one of "hex", "base64", "utf8"
        (case-insensitive)."""
        if isinstance(bin_fmt, string_types):
            bin_fmt = bin_fmt.lower()
        _available_options = ["hex", "base64", "utf8"]
        if bin_fmt not in _available_options:
            raise TypeError(f"Binary format should be one of : {_available_options}")
        self.options["BINARY_FORMAT"] = bin_fmt
        return self

    def escape(self, esc):
        """Character used as the escape character for any field values."""
        return self._set_single_char_option("ESCAPE", esc, "Escape")

    def escape_unenclosed_field(self, esc):
        """Single character used as the escape character for unenclosed field values only."""
        return self._set_single_char_option(
            "ESCAPE_UNENCLOSED_FIELD", esc, "Escape unenclosed field"
        )

    def field_optionally_enclosed_by(self, enc):
        """Character used to enclose strings. Either None, ', or \"."""
        # NOTE(review): "'" is stored as-is. CopyFormatter.value_repr wraps
        # strings in single quotes without escaping, so this may render as '''
        # -- verify against the compiled SQL.
        _available_options = [None, "'", '"']
        if enc not in _available_options:
            raise TypeError(f"Enclosing string should be one of : {_available_options}")
        self.options["FIELD_OPTIONALLY_ENCLOSED_BY"] = enc
        return self

    def null_if(self, null_value):
        """Strings that are turned into NULL when copying into a table. When copying
        out of Snowflake, NULL values are written as the first string.

        A single bare string is treated as a one-element list.
        """
        if isinstance(null_value, string_types):
            # A str is itself a Sequence; without this wrapping, "null" would
            # become the four markers ('n', 'u', 'l', 'l').
            null_value = (null_value,)
        if not isinstance(null_value, Sequence):
            raise TypeError("Parameter null_value should be an iterable")
        self.options["NULL_IF"] = tuple(null_value)
        return self

    def skip_header(self, skip_header):
        """
        Number of header rows to be skipped at the beginning of the file.
        """
        if not isinstance(skip_header, int):
            raise TypeError("skip_header should be an int")
        self.options["SKIP_HEADER"] = skip_header
        return self

    def trim_space(self, trim_space):
        """
        Remove leading or trailing white spaces.
        """
        if not isinstance(trim_space, bool):
            raise TypeError("trim_space should be a bool")
        self.options["TRIM_SPACE"] = trim_space
        return self

    def error_on_column_count_mismatch(self, error_on_col_count_mismatch):
        """
        Generate a parsing error if the number of delimited columns (i.e. fields) in
        an input data file does not match the number of columns in the corresponding table.
        """
        if not isinstance(error_on_col_count_mismatch, bool):
            raise TypeError("error_on_column_count_mismatch should be a bool")
        self.options["ERROR_ON_COLUMN_COUNT_MISMATCH"] = error_on_col_count_mismatch
        return self
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
class JSONFormatter(CopyFormatter):
    """JSON-specific FILE_FORMAT options (setters return ``self`` for chaining)."""

    file_format = "json"

    def compression(self, comp_type):
        """Compression algorithm of the data files (case-insensitive; ``None`` allowed)."""
        allowed = [
            "auto",
            "gzip",
            "bz2",
            "brotli",
            "zstd",
            "deflate",
            "raw_deflate",
            None,
        ]
        normalized = (
            comp_type.lower() if isinstance(comp_type, string_types) else comp_type
        )
        if normalized in allowed:
            self.options["COMPRESSION"] = normalized
            return self
        raise TypeError(f"Compression type should be one of : {allowed}")

    def file_extension(self, ext):
        """Extension for files unloaded to a stage (a string, or ``None``).
        No check is made that the extension is readable by any particular tool.
        """
        if ext is not None and not isinstance(ext, string_types):
            raise TypeError("File extension should be a string")
        self.options["FILE_EXTENSION"] = ext
        return self
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
class PARQUETFormatter(CopyFormatter):
    """Parquet-specific FILE_FORMAT options (setters return ``self`` for chaining)."""

    file_format = "parquet"

    def snappy_compression(self, comp):
        """Enable or disable Snappy compression; ``comp`` must be a ``bool``."""
        if isinstance(comp, bool):
            self.options["SNAPPY_COMPRESSION"] = translate_bool(comp)
            return self
        raise TypeError("Comp should be a Boolean value")

    def compression(self, comp):
        """Set the COMPRESSION option; ``comp`` must be a ``str`` (stored as given)."""
        if isinstance(comp, str):
            self.options["COMPRESSION"] = comp
            return self
        raise TypeError("Comp should be a str value")

    def binary_as_text(self, value):
        """Enable or disable BINARY_AS_TEXT; ``value`` must be a ``bool``."""
        if isinstance(value, bool):
            self.options["BINARY_AS_TEXT"] = translate_bool(value)
            return self
        raise TypeError("binary_as_text should be a Boolean value")
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
class ExternalStage(ClauseElement, FromClauseRole):
    """External Stage descriptor: ``@[namespace.]name[/path]``."""

    __visit_name__ = "external_stage"
    _hide_froms = ()

    @staticmethod
    def prepare_namespace(namespace):
        """Return ``namespace``, appending a ``.`` if it does not already end with one."""
        return f"{namespace}." if not namespace.endswith(".") else namespace

    @staticmethod
    def prepare_path(path):
        """Return ``path``, prefixing a ``/`` if it does not already start with one."""
        return f"/{path}" if not path.startswith("/") else path

    def __init__(self, name, path=None, namespace=None, file_format=None):
        self.name = name
        # Normalised to "" or "/..." so it can be appended directly to the name.
        self.path = self.prepare_path(path) if path else ""
        # Normalised to "" or "<namespace>." so it can be prepended to the name.
        self.namespace = self.prepare_namespace(namespace) if namespace else ""
        self.file_format = file_format

    def __repr__(self):
        return f"@{self.namespace}{self.name}{self.path} ({self.file_format})"

    @classmethod
    def from_parent_stage(cls, parent_stage, path, file_format=None):
        """
        Extend an existing parent stage (with or without path) with an
        additional sub-path.

        Slashes at the join point are collapsed, so a sub-path with a leading
        "/" or a parent path with a trailing "/" does not produce "//".
        The parent's ``file_format`` is not inherited.
        """
        parent_path = parent_stage.path.rstrip("/")
        # Treat a None sub-path as empty instead of rendering the text "None".
        sub_path = (path or "").lstrip("/")
        return cls(
            parent_stage.name,
            f"{parent_path}/{sub_path}",
            parent_stage.namespace,
            file_format,
        )
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
class CreateFileFormat(DDLElement):
    """
    ``CREATE FILE FORMAT`` DDL statement. It pairs a format name with a
    formatter object of the same kind used inside COPY INTO.
    """

    __visit_name__ = "create_file_format"

    def __init__(self, format_name, formatter, replace_if_exists=False):
        super().__init__()
        self.format_name, self.formatter, self.replace_if_exists = (
            format_name,
            formatter,
            replace_if_exists,
        )
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
class CreateStage(DDLElement):
    """
    ``CREATE STAGE`` DDL statement. It pairs a storage container (the
    physical location backing the stage) with the ExternalStage that names it.
    """

    __visit_name__ = "create_stage"

    def __init__(self, container, stage, replace_if_exists=False, *, temporary=False):
        super().__init__()
        self.container, self.temporary = container, temporary
        self.stage, self.replace_if_exists = stage, replace_if_exists
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
class AWSBucket(ClauseElement):
    """AWS S3 bucket descriptor, with optional credentials and encryption settings."""

    __visit_name__ = "aws_bucket"

    def __init__(self, bucket, path=None):
        self.bucket = bucket
        self.path = path
        self.encryption_used = {}
        self.credentials_used = {}

    @classmethod
    def from_uri(cls, uri):
        """Build an AWSBucket from ``s3://<bucket>[/<path>]``.

        :raises ValueError: if ``uri`` does not start with ``s3://``.
        """
        prefix = "s3://"
        if not uri.startswith(prefix):
            raise ValueError(f"Invalid AWS bucket URI: {uri}")
        bucket, sep, path = uri[len(prefix):].partition("/")
        # No "/" after the bucket -> no path at all (None, not "").
        return cls(bucket, path if sep else None)

    def __repr__(self):
        credentials = "CREDENTIALS=({})".format(
            " ".join(f"{n}='{v}'" for n, v in self.credentials_used.items())
        )
        encryption = "ENCRYPTION=({})".format(
            " ".join(
                f"{n}='{v}'" if isinstance(v, string_types) else f"{n}={v}"
                for n, v in self.encryption_used.items()
            )
        )
        uri = "'s3://{}{}'".format(self.bucket, f"/{self.path}" if self.path else "")
        return "{}{}{}".format(
            uri,
            f" {credentials}" if self.credentials_used else "",
            f" {encryption}" if self.encryption_used else "",
        )

    def credentials(
        self, aws_role=None, aws_key_id=None, aws_secret_key=None, aws_token=None
    ):
        """Set credentials: either ``aws_role``, or BOTH ``aws_key_id`` and
        ``aws_secret_key``. ``aws_token`` is added when given.

        :raises ValueError: if no role is given and either key part is missing.
        """
        # Previously only "both keys missing" was rejected, so a lone key id
        # (or secret) produced a credential rendered as 'None'.
        if not aws_role and (aws_key_id is None or aws_secret_key is None):
            raise ValueError(
                "Either 'aws_role', or aws_key_id and aws_secret_key has to be supplied"
            )
        if aws_role:
            self.credentials_used = {"AWS_ROLE": aws_role}
        else:
            self.credentials_used = {
                "AWS_SECRET_KEY": aws_secret_key,
                "AWS_KEY_ID": aws_key_id,
            }
        if aws_token:
            self.credentials_used["AWS_TOKEN"] = aws_token
        return self

    def encryption_aws_cse(self, master_key):
        """Use client-side encryption with ``master_key``."""
        self.encryption_used = {"TYPE": "AWS_CSE", "MASTER_KEY": master_key}
        return self

    def encryption_aws_sse_s3(self):
        """Use AWS_SSE_S3 server-side encryption."""
        self.encryption_used = {"TYPE": "AWS_SSE_S3"}
        return self

    def encryption_aws_sse_kms(self, kms_key_id=None):
        """Use AWS_SSE_KMS server-side encryption, optionally with a KMS key id."""
        self.encryption_used = {"TYPE": "AWS_SSE_KMS"}
        if kms_key_id:
            self.encryption_used["KMS_KEY_ID"] = kms_key_id
        return self
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
class AzureContainer(ClauseElement):
    """Microsoft Azure Container descriptor, with optional credentials and encryption."""

    __visit_name__ = "azure_container"

    def __init__(self, account, container, path=None):
        self.account = account
        self.container = container
        self.path = path
        self.encryption_used = {}
        self.credentials_used = {}

    @classmethod
    def from_uri(cls, uri):
        """Build an AzureContainer from
        ``azure://<account>.blob.core.windows.net/<container>[/<path>]``.

        :raises ValueError: if ``uri`` does not match that shape. The message
            always reports the full URI that was passed in.
        """
        prefix = "azure://"
        host_suffix = "blob.core.windows.net/"
        if not uri.startswith(prefix):
            raise ValueError(f"Invalid Azure Container URI: {uri}")
        # Use new names instead of rebinding ``uri``, so error messages show
        # the original input; a missing "." is reported clearly, not as an
        # unpacking error.
        account, dot, rest = uri[len(prefix):].partition(".")
        if not dot or not rest.startswith(host_suffix):
            raise ValueError(f"Invalid Azure Container URI: {uri}")
        container, sep, path = rest[len(host_suffix):].partition("/")
        return cls(account, container, path if sep else None)

    def __repr__(self):
        credentials = "CREDENTIALS=({})".format(
            " ".join(f"{n}='{v}'" for n, v in self.credentials_used.items())
        )
        encryption = "ENCRYPTION=({})".format(
            " ".join(
                f"{n}='{v}'" if isinstance(v, string_types) else f"{n}={v}"
                for n, v in self.encryption_used.items()
            )
        )
        uri = "'azure://{}.blob.core.windows.net/{}{}'".format(
            self.account, self.container, f"/{self.path}" if self.path else ""
        )
        return "{}{}{}".format(
            uri,
            f" {credentials}" if self.credentials_used else "",
            f" {encryption}" if self.encryption_used else "",
        )

    def credentials(self, azure_sas_token):
        """Authenticate with an Azure SAS token."""
        self.credentials_used = {"AZURE_SAS_TOKEN": azure_sas_token}
        return self

    def encryption_azure_cse(self, master_key):
        """Use client-side encryption with ``master_key``."""
        self.encryption_used = {"TYPE": "AZURE_CSE", "MASTER_KEY": master_key}
        return self
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
# Second public name bound to the same CopyInto class object.
# NOTE(review): presumably kept for backward compatibility -- confirm before removing.
CopyIntoStorage = CopyInto
|