snowflake_sqlalchemy-1.7.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. snowflake/sqlalchemy/__init__.py +162 -0
  2. snowflake/sqlalchemy/_constants.py +14 -0
  3. snowflake/sqlalchemy/base.py +1188 -0
  4. snowflake/sqlalchemy/compat.py +36 -0
  5. snowflake/sqlalchemy/custom_commands.py +627 -0
  6. snowflake/sqlalchemy/custom_types.py +155 -0
  7. snowflake/sqlalchemy/exc.py +82 -0
  8. snowflake/sqlalchemy/functions.py +16 -0
  9. snowflake/sqlalchemy/parser/custom_type_parser.py +245 -0
  10. snowflake/sqlalchemy/provision.py +12 -0
  11. snowflake/sqlalchemy/requirements.py +313 -0
  12. snowflake/sqlalchemy/snowdialect.py +1029 -0
  13. snowflake/sqlalchemy/sql/__init__.py +3 -0
  14. snowflake/sqlalchemy/sql/custom_schema/__init__.py +9 -0
  15. snowflake/sqlalchemy/sql/custom_schema/clustered_table.py +37 -0
  16. snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py +127 -0
  17. snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py +13 -0
  18. snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py +117 -0
  19. snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py +63 -0
  20. snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py +102 -0
  21. snowflake/sqlalchemy/sql/custom_schema/options/__init__.py +33 -0
  22. snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py +63 -0
  23. snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py +58 -0
  24. snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py +63 -0
  25. snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py +25 -0
  26. snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py +65 -0
  27. snowflake/sqlalchemy/sql/custom_schema/options/keywords.py +14 -0
  28. snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py +67 -0
  29. snowflake/sqlalchemy/sql/custom_schema/options/table_option.py +84 -0
  30. snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py +94 -0
  31. snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py +70 -0
  32. snowflake/sqlalchemy/sql/custom_schema/table_from_query.py +54 -0
  33. snowflake/sqlalchemy/util.py +344 -0
  34. snowflake/sqlalchemy/version.py +6 -0
  35. snowflake_sqlalchemy-1.7.3.dist-info/METADATA +737 -0
  36. snowflake_sqlalchemy-1.7.3.dist-info/RECORD +39 -0
  37. snowflake_sqlalchemy-1.7.3.dist-info/WHEEL +4 -0
  38. snowflake_sqlalchemy-1.7.3.dist-info/entry_points.txt +2 -0
  39. snowflake_sqlalchemy-1.7.3.dist-info/licenses/LICENSE.txt +202 -0
@@ -0,0 +1,36 @@
1
+ #
2
+ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ from typing import Callable
7
+
8
+ from sqlalchemy import __version__ as SA_VERSION
9
+ from sqlalchemy import util
10
+
11
string_types = (str,)
returns_unicode = util.symbol("RETURNS_UNICODE")


def _parse_version_tuple(version: str, width: int = 3) -> tuple:
    """Return the leading numeric release components of *version* as ints.

    Only the leading digits of each dotted component are used. Pre-release or
    dev suffixes such as ``"2.0.0b4"`` or ``"2.1.0.dev0"`` are therefore
    ignored instead of making ``int()`` raise ValueError at import time. The
    result is zero-padded to *width* components, so ``"2.0"`` compares equal
    to ``(2, 0, 0)`` rather than less than it.
    """
    import re

    parts = []
    for piece in version.split(".")[:width]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    parts.extend([0] * (width - len(parts)))
    return tuple(parts)


IS_VERSION_20 = _parse_version_tuple(SA_VERSION) >= (2, 0, 0)
15
+
16
+
17
def args_reducer(positions_to_drop: tuple):
    """Build a decorator that drops positional arguments on SQLAlchemy < 2.0.

    On SQLAlchemy 2.x the wrapped function receives its positional arguments
    unchanged. On older versions, every positional argument whose zero-based
    index appears in *positions_to_drop* is removed before the call. For
    example, ``(3, 5)`` removes ``args[3]`` and ``args[5]``. Keep in mind
    that on methods index 0 is ``cls`` or ``self``.

    Keyword arguments are forwarded untouched, and the wrapped function's
    return value is returned to the caller.
    """
    drop = frozenset(positions_to_drop)

    def fn_wrapper(fn: Callable):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            reduced_args = args
            if not IS_VERSION_20:
                reduced_args = tuple(
                    arg for idx, arg in enumerate(args) if idx not in drop
                )
            # Propagate the result: previously it was silently discarded.
            return fn(*reduced_args, **kwargs)

        return wrapper

    return fn_wrapper
@@ -0,0 +1,627 @@
1
+ #
2
+ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+
5
+ from collections.abc import Sequence
6
+ from typing import List
7
+
8
+ from sqlalchemy import false, true
9
+ from sqlalchemy.sql.ddl import DDLElement
10
+ from sqlalchemy.sql.dml import UpdateBase
11
+ from sqlalchemy.sql.elements import ClauseElement
12
+ from sqlalchemy.sql.roles import FromClauseRole
13
+
14
+ from .compat import string_types
15
+
16
# The class of the ``None`` singleton; used in isinstance() checks below that
# accept an optional (None) value.
NoneType = None.__class__
17
+
18
+
19
def translate_bool(bln):
    """Return SQLAlchemy's ``true()`` construct for a truthy *bln*, else ``false()``."""
    return true() if bln else false()
23
+
24
+
25
class MergeInto(UpdateBase):
    """``MERGE INTO`` statement builder.

    ``target``, ``source`` and ``on`` are stored as given. WHEN [NOT] MATCHED
    sub-clauses are added with the ``when_*`` methods. Each of those appends
    a new :class:`MergeInto.clause` to ``self.clauses`` and returns it, so it
    can be refined further with ``values()`` and ``where()``.
    """

    __visit_name__ = "merge_into"
    _bind = None

    def __init__(self, target, source, on):
        self.target = target
        self.source = source
        self.on = on
        self.clauses = []

    class clause(ClauseElement):
        """A single ``WHEN [NOT] MATCHED ... THEN <command>`` sub-clause."""

        __visit_name__ = "merge_into_clause"

        def __init__(self, command):
            self.set = {}  # column name -> value, filled by values()
            self.predicate = None  # optional extra condition, set by where()
            self.command = command  # "INSERT", "UPDATE" or "DELETE"

        def __repr__(self):
            case_predicate = (
                f" AND {str(self.predicate)}" if self.predicate is not None else ""
            )
            if self.command == "INSERT":
                # Join keys and values separately. Unpacking zip(*items())
                # raised ValueError for an INSERT clause without values().
                columns = ", ".join(self.set.keys())
                values = ", ".join(map(str, self.set.values()))
                return "WHEN NOT MATCHED{} THEN {} ({}) VALUES ({})".format(
                    case_predicate,
                    self.command,
                    columns,
                    values,
                )
            # WHEN MATCHED clause (UPDATE / DELETE)
            assignments = ", ".join(
                f"{column} = {value}" for column, value in self.set.items()
            )
            return "WHEN MATCHED{} THEN {}{}".format(
                case_predicate,
                self.command,
                f" SET {assignments}" if self.set else "",
            )

        def values(self, **kwargs):
            """Replace the column/value mapping of this clause; returns ``self``."""
            self.set = kwargs
            return self

        def where(self, expr):
            """Set the additional ``AND`` predicate of this clause; returns ``self``."""
            self.predicate = expr
            return self

    def __repr__(self):
        clauses = " ".join([repr(clause) for clause in self.clauses])
        return f"MERGE INTO {self.target} USING {self.source} ON {self.on}" + (
            f" {clauses}" if clauses else ""
        )

    def when_matched_then_update(self):
        """Append and return a WHEN MATCHED ... THEN UPDATE clause."""
        clause = self.clause("UPDATE")
        self.clauses.append(clause)
        return clause

    def when_matched_then_delete(self):
        """Append and return a WHEN MATCHED ... THEN DELETE clause."""
        clause = self.clause("DELETE")
        self.clauses.append(clause)
        return clause

    def when_not_matched_then_insert(self):
        """Append and return a WHEN NOT MATCHED ... THEN INSERT clause."""
        clause = self.clause("INSERT")
        self.clauses.append(clause)
        return clause
96
+
97
+
98
class FilesOption:
    """
    Class to represent FILES option for the snowflake COPY INTO statement.

    ``str()`` renders the file names as a parenthesised, comma-separated list
    of single-quoted string literals, e.g. ``('a.csv','b.csv')``.
    """

    def __init__(self, file_names: List[str]):
        self.file_names = file_names

    @staticmethod
    def _quote(file_name: str) -> str:
        """Return *file_name* as a single-quoted literal with escapes applied.

        Backslashes are escaped first. Otherwise a name ending in a backslash
        would escape the closing quote and corrupt (or inject into) the rest
        of the statement. Single quotes are then escaped as ``\\'``.
        """
        escaped = file_name.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def __str__(self):
        return f"({','.join(self._quote(f) for f in self.file_names)})"
109
+
110
+
111
class CopyInto(UpdateBase):
    """Copy Into Command base class, for documentation see:
    https://docs.snowflake.com/en/sql-reference/sql/copy-into-location

    The option setters (``force``, ``single``, ``maxfilesize``, ``files``,
    ``pattern``) record their value in ``self.copy_options`` under the
    Snowflake option name and return ``self`` so calls can be chained.
    """

    __visit_name__ = "copy_into"
    _bind = None

    def __init__(self, from_, into, partition_by=None, formatter=None):
        self.from_ = from_
        self.into = into
        self.formatter = formatter
        self.copy_options = {}
        self.partition_by = partition_by

    def __repr__(self):
        """
        repr for debugging / logging purposes only. For compilation logic, see
        the corresponding visitor in base.py
        """
        val = f"COPY INTO {self.into} FROM {repr(self.from_)}"
        if self.partition_by is not None:
            val += f" PARTITION BY {self.partition_by}"

        return val + f" {repr(self.formatter)} ({self.copy_options})"

    def bind(self):
        """Always returns ``None``."""
        return None

    def force(self, force):
        """Set the FORCE option; *force* must be a bool."""
        if not isinstance(force, bool):
            raise TypeError("Parameter force should be a boolean value")
        self.copy_options.update({"FORCE": translate_bool(force)})
        return self

    def single(self, single_file):
        """Set the SINGLE option; *single_file* must be a bool."""
        if not isinstance(single_file, bool):
            raise TypeError("Parameter single_file should be a boolean value")
        self.copy_options.update({"SINGLE": translate_bool(single_file)})
        return self

    def maxfilesize(self, max_size):
        """Set the MAX_FILE_SIZE option; *max_size* must be an int (not a bool)."""
        # bool is a subclass of int; reject it explicitly so that
        # maxfilesize(True) is not accepted as a size.
        if isinstance(max_size, bool) or not isinstance(max_size, int):
            raise TypeError("Parameter max_size should be an integer value")
        self.copy_options.update({"MAX_FILE_SIZE": max_size})
        return self

    def files(self, file_names):
        """Set the FILES option from a list of file names."""
        self.copy_options.update({"FILES": FilesOption(file_names)})
        return self

    def pattern(self, pattern):
        """Set the PATTERN option; the value is stored as given."""
        self.copy_options.update({"PATTERN": pattern})
        return self
164
+
165
+
166
class CopyFormatter(ClauseElement):
    """
    Common base for FILE_FORMAT specifications inside COPY INTO.

    Subclasses add format-specific setters that write into ``self.options``.
    An instance can also describe a named file format: the optional
    *format_name* is stored under the ``"format_name"`` key.
    """

    __visit_name__ = "copy_formatter"

    def __init__(self, format_name=None):
        self.options = {}
        if format_name:
            self.options["format_name"] = format_name

    def __repr__(self):
        """
        Debug/logging representation only; the SQL text itself is produced by
        the matching visitor in base.py.
        """
        return "FILE_FORMAT=({})".format(self.options)

    @staticmethod
    def value_repr(name, value):
        """
        Render an option *value* as SQL text for base.py's
        visit_copy_formatter():

        - the ``format_name`` option is returned unquoted;
        - a string is wrapped in single quotes;
        - a one-element tuple becomes ``('<element>')``, avoiding the trailing
          comma Python's own ``str()`` would add;
        - anything else goes through ``str()`` unchanged.
        """
        if name == "format_name":
            return value
        if isinstance(value, str):
            return f"'{value}'"
        is_singleton_tuple = isinstance(value, tuple) and len(value) == 1
        if is_singleton_tuple:
            return f"('{value[0]}')"
        return str(value)
206
+
207
+
208
class CSVFormatter(CopyFormatter):
    """CSV-specific FILE_FORMAT options.

    Each setter validates its argument, stores it in ``self.options`` under
    the Snowflake option name, and returns ``self`` so calls can be chained.
    Invalid arguments raise TypeError.
    """

    file_format = "csv"

    def compression(self, comp_type):
        """Compression algorithm for the data files.

        Strings are matched case-insensitively and stored lowercased. ``None``
        is also accepted. Any other value raises TypeError.
        """
        if isinstance(comp_type, string_types):
            comp_type = comp_type.lower()
        _available_options = [
            "auto",
            "gzip",
            "bz2",
            "brotli",
            "zstd",
            "deflate",
            "raw_deflate",
            None,
        ]
        if comp_type not in _available_options:
            raise TypeError(f"Compression type should be one of : {_available_options}")
        self.options["COMPRESSION"] = comp_type
        return self

    def _check_delimiter(self, delimiter, delimiter_txt):
        """
        Check that a delimiter is ``None``, an integer, or a string denoting
        a single character.

        Backslash escapes are resolved before measuring, so ``"\\n"`` (a
        backslash followed by ``n``) counts as one character. Characters
        outside Latin-1 are first turned into ``\\uXXXX`` escapes, so a
        multibyte character (e.g. U+20AC) still counts as one character.
        Encoding to UTF-8 first, as before, split it into several characters.
        A string that is not a complete escape sequence on its own, such as a
        lone backslash, is measured as-is instead of raising
        UnicodeDecodeError.

        :raises TypeError: if *delimiter* is none of the accepted forms.
        """
        if isinstance(delimiter, NoneType):
            return
        if isinstance(delimiter, string_types):
            try:
                delimiter_processed = delimiter.encode(
                    "latin-1", "backslashreplace"
                ).decode("unicode_escape")
            except UnicodeDecodeError:
                delimiter_processed = delimiter
            if len(delimiter_processed) == 1:
                return
        if isinstance(delimiter, int):
            return
        raise TypeError(
            f"{delimiter_txt} should be a single character, that is either a string, or a number"
        )

    def record_delimiter(self, deli_type):
        """Character that separates records; an int is stored as its ``hex()`` string."""
        self._check_delimiter(deli_type, "Record delimiter")
        if isinstance(deli_type, int):
            self.options["RECORD_DELIMITER"] = hex(deli_type)
        else:
            self.options["RECORD_DELIMITER"] = deli_type
        return self

    def field_delimiter(self, deli_type):
        """Character that separates fields; an int is stored as its ``hex()`` string."""
        self._check_delimiter(deli_type, "Field delimiter")
        if isinstance(deli_type, int):
            self.options["FIELD_DELIMITER"] = hex(deli_type)
        else:
            self.options["FIELD_DELIMITER"] = deli_type
        return self

    def file_extension(self, ext):
        """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is
        responsible for specifying a valid file extension that can be read by the desired software or service.
        ``None`` is also accepted.
        """
        if not isinstance(ext, (NoneType, string_types)):
            raise TypeError("File extension should be a string")
        self.options["FILE_EXTENSION"] = ext
        return self

    def date_format(self, dt_frmt):
        """String that defines the format of date values in the data files."""
        if not isinstance(dt_frmt, string_types):
            raise TypeError("Date format should be a string")
        self.options["DATE_FORMAT"] = dt_frmt
        return self

    def time_format(self, tm_frmt):
        """String that defines the format of time values in the data files."""
        if not isinstance(tm_frmt, string_types):
            raise TypeError("Time format should be a string")
        self.options["TIME_FORMAT"] = tm_frmt
        return self

    def timestamp_format(self, tmstmp_frmt):
        """String that defines the format of timestamp values in the data files."""
        if not isinstance(tmstmp_frmt, string_types):
            raise TypeError("Timestamp format should be a string")
        self.options["TIMESTAMP_FORMAT"] = tmstmp_frmt
        return self

    def binary_format(self, bin_fmt):
        """Encoding format for binary data: one of "hex", "base64" or "utf8".

        Strings are matched case-insensitively and stored lowercased.
        """
        if isinstance(bin_fmt, string_types):
            bin_fmt = bin_fmt.lower()
        _available_options = ["hex", "base64", "utf8"]
        if bin_fmt not in _available_options:
            raise TypeError(f"Binary format should be one of : {_available_options}")
        self.options["BINARY_FORMAT"] = bin_fmt
        return self

    def escape(self, esc):
        """Character used as the escape character for any field values."""
        self._check_delimiter(esc, "Escape")
        if isinstance(esc, int):
            self.options["ESCAPE"] = hex(esc)
        else:
            self.options["ESCAPE"] = esc
        return self

    def escape_unenclosed_field(self, esc):
        """Single character string used as the escape character for unenclosed field values only."""
        self._check_delimiter(esc, "Escape unenclosed field")
        if isinstance(esc, int):
            self.options["ESCAPE_UNENCLOSED_FIELD"] = hex(esc)
        else:
            self.options["ESCAPE_UNENCLOSED_FIELD"] = esc
        return self

    def field_optionally_enclosed_by(self, enc):
        """Character used to enclose strings. Either None, ', or \"."""
        _available_options = [None, "'", '"']
        if enc not in _available_options:
            raise TypeError(f"Enclosing string should be one of : {_available_options}")
        self.options["FIELD_OPTIONALLY_ENCLOSED_BY"] = enc
        return self

    def null_if(self, null_value):
        """Strings that are converted to NULL when loading. When unloading,
        NULL values are written as the first string.

        A single string is treated as a one-element list. Otherwise *null_value*
        must be a sequence of strings.
        """
        if isinstance(null_value, string_types):
            # A bare str is itself a Sequence; without this it would be
            # exploded into one NULL marker per character.
            null_value = (null_value,)
        if not isinstance(null_value, Sequence):
            raise TypeError("Parameter null_value should be an iterable")
        self.options["NULL_IF"] = tuple(null_value)
        return self

    def skip_header(self, skip_header):
        """
        Number of header rows to be skipped at the beginning of the file
        """
        if not isinstance(skip_header, int):
            raise TypeError("skip_header should be an int")
        self.options["SKIP_HEADER"] = skip_header
        return self

    def trim_space(self, trim_space):
        """
        Remove leading or trailing white spaces
        """
        if not isinstance(trim_space, bool):
            raise TypeError("trim_space should be a bool")
        self.options["TRIM_SPACE"] = trim_space
        return self

    def error_on_column_count_mismatch(self, error_on_col_count_mismatch):
        """
        Generate a parsing error if the number of delimited columns (i.e. fields) in
        an input data file does not match the number of columns in the corresponding table.
        """
        if not isinstance(error_on_col_count_mismatch, bool):
            # The message previously named "skip_header" (copy-paste error).
            raise TypeError("error_on_column_count_mismatch should be a bool")
        self.options["ERROR_ON_COLUMN_COUNT_MISMATCH"] = error_on_col_count_mismatch
        return self
368
+
369
+
370
class JSONFormatter(CopyFormatter):
    """JSON-specific FILE_FORMAT options; every setter returns ``self``."""

    file_format = "json"

    def compression(self, comp_type):
        """Compression algorithm for the data files.

        Strings are lowercased before validation. ``None`` is also allowed.
        Anything else raises TypeError.
        """
        allowed = [
            "auto",
            "gzip",
            "bz2",
            "brotli",
            "zstd",
            "deflate",
            "raw_deflate",
            None,
        ]
        normalized = (
            comp_type.lower() if isinstance(comp_type, string_types) else comp_type
        )
        if normalized in allowed:
            self.options["COMPRESSION"] = normalized
            return self
        raise TypeError(f"Compression type should be one of : {allowed}")

    def file_extension(self, ext):
        """Extension for files unloaded to a stage (any string, or ``None``).

        No check is made that the extension is meaningful for the consumer of
        the files; that is the caller's responsibility.
        """
        if isinstance(ext, (NoneType, string_types)):
            self.options["FILE_EXTENSION"] = ext
            return self
        raise TypeError("File extension should be a string")
402
+
403
+
404
class PARQUETFormatter(CopyFormatter):
    """Parquet-specific FILE_FORMAT options; every setter returns ``self``."""

    file_format = "parquet"

    def snappy_compression(self, comp):
        """Switch Snappy compression on (``True``) or off (``False``)."""
        if isinstance(comp, bool):
            self.options["SNAPPY_COMPRESSION"] = translate_bool(comp)
            return self
        raise TypeError("Comp should be a Boolean value")

    def compression(self, comp):
        """Store *comp* verbatim as the COMPRESSION option (must be a str)."""
        if isinstance(comp, str):
            self.options["COMPRESSION"] = comp
            return self
        raise TypeError("Comp should be a str value")

    def binary_as_text(self, value):
        """Switch the BINARY_AS_TEXT option on (``True``) or off (``False``)."""
        if isinstance(value, bool):
            self.options["BINARY_AS_TEXT"] = translate_bool(value)
            return self
        raise TypeError("binary_as_text should be a Boolean value")
431
+
432
+
433
class ExternalStage(ClauseElement, FromClauseRole):
    """External Stage descriptor.

    ``namespace`` is normalized to end with ``"."`` and ``path`` to start with
    ``"/"``. Both become ``""`` when not given, so ``__repr__`` can simply
    concatenate them around the stage name.
    """

    __visit_name__ = "external_stage"
    _hide_froms = ()

    @staticmethod
    def prepare_namespace(namespace):
        """Append a trailing ``"."`` to *namespace* unless it already ends with one."""
        return f"{namespace}." if not namespace.endswith(".") else namespace

    @staticmethod
    def prepare_path(path):
        """Prepend a leading ``"/"`` to *path* unless it already starts with one."""
        return f"/{path}" if not path.startswith("/") else path

    def __init__(self, name, path=None, namespace=None, file_format=None):
        self.name = name
        self.path = self.prepare_path(path) if path else ""
        self.namespace = self.prepare_namespace(namespace) if namespace else ""
        self.file_format = file_format

    def __repr__(self):
        return f"@{self.namespace}{self.name}{self.path} ({self.file_format})"

    @classmethod
    def from_parent_stage(cls, parent_stage, path, file_format=None):
        """
        Extend an existing parent stage (with or without path) with an
        additional sub-path.

        Exactly one ``"/"`` is placed between the parent's path and *path*.
        This holds even when the parent path ends with a slash or *path*
        starts with one, so no ``"//"`` is produced.
        """
        parent_path = parent_stage.path.rstrip("/")
        sub_path = f"{path}".lstrip("/")
        return cls(
            parent_stage.name,
            f"{parent_path}/{sub_path}",
            parent_stage.namespace,
            file_format,
        )
468
+
469
+
470
class CreateFileFormat(DDLElement):
    """
    DDL element describing a CREATE FILE FORMAT statement.

    It pairs a format name with a formatter object, the same kind of format
    description used by COPY INTO. ``replace_if_exists`` is only stored here;
    how it is rendered is up to the compiler's visitor for this element,
    which is not part of this class.
    """

    __visit_name__ = "create_file_format"

    def __init__(self, format_name, formatter, replace_if_exists=False):
        super().__init__()
        self.replace_if_exists = replace_if_exists
        self.formatter = formatter
        self.format_name = format_name
483
+
484
+
485
class CreateStage(DDLElement):
    """
    DDL element describing a CREATE STAGE statement.

    It holds the storage container that physically backs the stage, the
    ExternalStage object being created, and two flags: ``replace_if_exists``
    and the keyword-only ``temporary``. Rendering is left to the compiler's
    visitor for this element.
    """

    __visit_name__ = "create_stage"

    def __init__(self, container, stage, replace_if_exists=False, *, temporary=False):
        super().__init__()
        self.replace_if_exists = replace_if_exists
        self.stage = stage
        self.temporary = temporary
        self.container = container
499
+
500
+
501
class AWSBucket(ClauseElement):
    """AWS S3 bucket descriptor.

    Credential and encryption settings are collected through the fluent
    ``credentials()`` / ``encryption_*()`` methods. Each replaces the
    previously stored dict and returns ``self``.
    """

    __visit_name__ = "aws_bucket"

    def __init__(self, bucket, path=None):
        self.bucket = bucket
        self.path = path
        self.encryption_used = {}
        self.credentials_used = {}

    @classmethod
    def from_uri(cls, uri):
        """Build a descriptor from an ``s3://<bucket>[/<path>]`` URI.

        :raises ValueError: if *uri* does not start with ``s3://``.
        """
        if uri[0:5] != "s3://":
            raise ValueError(f"Invalid AWS bucket URI: {uri}")
        b = uri[5:].split("/", 1)
        if len(b) == 1:
            bucket, path = b[0], None
        else:
            bucket, path = b
        return cls(bucket, path)

    def __repr__(self):
        credentials = "CREDENTIALS=({})".format(
            " ".join(f"{n}='{v}'" for n, v in self.credentials_used.items())
        )
        encryption = "ENCRYPTION=({})".format(
            " ".join(
                f"{n}='{v}'" if isinstance(v, string_types) else f"{n}={v}"
                for n, v in self.encryption_used.items()
            )
        )
        uri = "'s3://{}{}'".format(self.bucket, f"/{self.path}" if self.path else "")
        return "{}{}{}".format(
            uri,
            f" {credentials}" if self.credentials_used else "",
            f" {encryption}" if self.encryption_used else "",
        )

    def credentials(
        self, aws_role=None, aws_key_id=None, aws_secret_key=None, aws_token=None
    ):
        """Set AWS credentials: either *aws_role*, or both key id and secret key.

        An optional *aws_token* is added in either case.

        :raises ValueError: if no role is given and either the key id or the
            secret key is missing. Previously a lone key id or secret was
            accepted and the missing half was rendered as the string 'None'.
        """
        if not aws_role and (aws_key_id is None or aws_secret_key is None):
            raise ValueError(
                "Either 'aws_role', or aws_key_id and aws_secret_key has to be supplied"
            )
        if aws_role:
            self.credentials_used = {"AWS_ROLE": aws_role}
        else:
            self.credentials_used = {
                "AWS_SECRET_KEY": aws_secret_key,
                "AWS_KEY_ID": aws_key_id,
            }
        if aws_token:
            self.credentials_used["AWS_TOKEN"] = aws_token
        return self

    def encryption_aws_cse(self, master_key):
        """Use client-side encryption (AWS_CSE) with *master_key*."""
        self.encryption_used = {"TYPE": "AWS_CSE", "MASTER_KEY": master_key}
        return self

    def encryption_aws_sse_s3(self):
        """Use server-side encryption with S3-managed keys (AWS_SSE_S3)."""
        self.encryption_used = {"TYPE": "AWS_SSE_S3"}
        return self

    def encryption_aws_sse_kms(self, kms_key_id=None):
        """Use server-side KMS encryption (AWS_SSE_KMS), optionally with *kms_key_id*."""
        self.encryption_used = {"TYPE": "AWS_SSE_KMS"}
        if kms_key_id:
            self.encryption_used["KMS_KEY_ID"] = kms_key_id
        return self
571
+
572
+
573
class AzureContainer(ClauseElement):
    """Microsoft Azure Container descriptor.

    Credential and encryption settings are set through the fluent
    ``credentials()`` / ``encryption_azure_cse()`` methods, which return
    ``self``.
    """

    __visit_name__ = "azure_container"

    def __init__(self, account, container, path=None):
        self.account = account
        self.container = container
        self.path = path
        self.encryption_used = {}
        self.credentials_used = {}

    @classmethod
    def from_uri(cls, uri):
        """Build a descriptor from
        ``azure://<account>.blob.core.windows.net/<container>[/<path>]``.

        :raises ValueError: if *uri* does not have that shape. The message
            always quotes the full URI the caller passed. Previously a URI
            without "." failed with an opaque unpacking error, and the host
            check reported only the truncated remainder.
        """
        prefix = "azure://"
        host = "blob.core.windows.net/"
        if not uri.startswith(prefix):
            raise ValueError(f"Invalid Azure Container URI: {uri}")
        account, dot, remainder = uri[len(prefix):].partition(".")
        if not dot or not remainder.startswith(host):
            raise ValueError(f"Invalid Azure Container URI: {uri}")
        container, slash, path = remainder[len(host):].partition("/")
        # Like split("/", 1): no slash -> path None; trailing slash -> path "".
        return cls(account, container, path if slash else None)

    def __repr__(self):
        credentials = "CREDENTIALS=({})".format(
            " ".join(f"{n}='{v}'" for n, v in self.credentials_used.items())
        )
        encryption = "ENCRYPTION=({})".format(
            " ".join(
                f"{n}='{v}'" if isinstance(v, string_types) else f"{n}={v}"
                for n, v in self.encryption_used.items()
            )
        )
        uri = "'azure://{}.blob.core.windows.net/{}{}'".format(
            self.account, self.container, f"/{self.path}" if self.path else ""
        )
        return "{}{}{}".format(
            uri,
            f" {credentials}" if self.credentials_used else "",
            f" {encryption}" if self.encryption_used else "",
        )

    def credentials(self, azure_sas_token):
        """Authenticate with an Azure SAS token (replaces any previous credentials)."""
        self.credentials_used = {"AZURE_SAS_TOKEN": azure_sas_token}
        return self

    def encryption_azure_cse(self, master_key):
        """Use client-side encryption (AZURE_CSE) with *master_key*."""
        self.encryption_used = {"TYPE": "AZURE_CSE", "MASTER_KEY": master_key}
        return self
625
+
626
+
627
# Alias: CopyIntoStorage is the very same class object as CopyInto, exposed
# under a second public name.
CopyIntoStorage = CopyInto