ml-analytics-tools 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2615 @@
1
+ """
2
+ Generic utility functions for data processing and database connection.
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ import re
8
+ import threading
9
+ import time
10
+
11
+ import boto3
12
+ import pandas as pd
13
+ import polars as pl
14
+ import pyarrow as pa
15
+ import pyarrow.parquet as pq
16
+ import redshift_connector
17
+
18
+ from .s3_connector import S3Connector
19
+ from .utils import (
20
+ _split_sql_statements,
21
+ get_credential_value,
22
+ get_logger,
23
+ load_sql_query,
24
+ log_and_raise_error,
25
+ )
26
+
27
+
28
+ class DataConnector:
29
def __init__(
    self,
    *,
    database=None,
    user=None,
    password=None,
    host=None,
    port=None,
    s3_bucket=None,
    timeout=240,
):
    """
    Prepare a DataConnector without opening any connection yet.

    Every Redshift parameter that is passed as a falsy value is looked up through
    ``get_credential_value`` under its ``BI_REDSHIFT_*`` name; a missing ``s3_bucket``
    falls back to the ``ML_ANALYTICS_S3_BUCKET`` environment variable.

    The connection itself is opened lazily by ``connect()`` and is closed again by
    a background timer once ``timeout`` seconds pass without activity.
    """
    explicit_values = {
        "database": database,
        "user": user,
        "password": password,
        "host": host,
        "port": port,
    }
    credential_names = {
        "database": "BI_REDSHIFT_DB",
        "user": "BI_REDSHIFT_USER",
        "password": "BI_REDSHIFT_PASSWORD",
        "host": "BI_REDSHIFT_HOST",
        "port": "BI_REDSHIFT_PORT",
    }
    # The credential lookup only runs for parameters the caller did not supply.
    self._db_params = {
        name: value or get_credential_value(credential_names[name])
        for name, value in explicit_values.items()
    }
    self._s3_bucket = s3_bucket or os.getenv("ML_ANALYTICS_S3_BUCKET")

    # Live connection state; populated by connect().
    self.connection = None
    self.cursor = None
    self._logger = get_logger("Data Connector")

    # Default-bucket connector plus a per-bucket cache of S3 connectors.
    self.s3 = None
    self._s3_connectors: dict[str, S3Connector] = {}

    # Idle-timeout bookkeeping.
    self._timeout = timeout
    self._last_activity = None
    self._idle_timer = None
    # Re-entrant on purpose: locked methods call other locked methods
    # (e.g. _close_if_idle -> close_redshift_connection -> _cancel_idle_timer),
    # which would deadlock with a plain Lock.
    self._lock = threading.RLock()
68
+
69
def _is_connection_open(self) -> bool:
    """Report whether ``self.connection`` currently looks usable.

    The connection object is inspected through whichever status attribute it
    exposes. Connections whose class lives in a ``redshift_connector*`` module
    are checked via ``is_closed`` then ``closed``; any other driver via
    ``closed``, ``is_closed``, ``open`` and ``is_valid`` in that order. When no
    such attribute exists, ``SELECT 1`` is executed on ``self.cursor`` as a
    probe. Any unexpected error is treated as "not open".
    """
    conn = getattr(self, "connection", None)
    if not conn:
        return False

    def read_flag(name):
        # A status attribute may be a plain value/property or a zero-arg method.
        flag = getattr(conn, name)
        return flag() if callable(flag) else flag

    def open_according_to_closed():
        # psycopg2-style ``closed`` is an int where 0 means open.
        flag = read_flag("closed")
        if isinstance(flag, int):
            return flag == 0
        return not bool(flag)

    def probe_cursor(recheck_connection):
        cur = getattr(self, "cursor", None)
        if cur is None:
            return False
        try:
            cur.execute("SELECT 1")
            # The reference may have been cleared while the probe ran.
            if recheck_connection and getattr(self, "connection", None) is None:
                return False
            return True
        except Exception:
            return False

    try:
        try:
            module_name = conn.__class__.__module__
        except Exception:
            module_name = ""

        if module_name.startswith("redshift_connector"):
            if hasattr(conn, "is_closed"):
                return not bool(read_flag("is_closed"))
            if hasattr(conn, "closed"):
                return open_according_to_closed()
            return probe_cursor(recheck_connection=True)

        # Generic DB-API drivers.
        if hasattr(conn, "closed"):
            return open_according_to_closed()
        if hasattr(conn, "is_closed"):
            return not bool(read_flag("is_closed"))
        if hasattr(conn, "open"):
            return bool(read_flag("open"))
        if hasattr(conn, "is_valid"):
            return bool(read_flag("is_valid"))
        return probe_cursor(recheck_connection=False)
    except Exception:
        return False
165
+
166
def _start_idle_timer(self):
    """(Re)arm the idle timer and stamp the current time as the last activity."""
    with self._lock:
        previous = self._idle_timer
        if previous:
            previous.cancel()
        # Daemon timer so a pending idle check never keeps the interpreter alive.
        timer = threading.Timer(self._timeout, self._close_if_idle)
        self._idle_timer = timer
        timer.daemon = True
        timer.start()
        self._last_activity = time.time()
175
+
176
def _cancel_idle_timer(self):
    """Stop the pending idle timer, if there is one, and forget it."""
    with self._lock:
        timer = self._idle_timer
        if not timer:
            return
        try:
            timer.cancel()
        except Exception:
            # Cancelling is best-effort; the reference is dropped either way.
            self._logger.debug("Failed to cancel idle timer", exc_info=True)
        self._idle_timer = None
185
+
186
def _close_if_idle(self):
    """Timer callback: close the connection once it has idled for the full timeout."""
    with self._lock:
        last_seen = self._last_activity
        if not self._is_connection_open():
            return
        idle_for = time.time() - last_seen
        if idle_for >= self._timeout:
            self._logger.info(f"Connection has been idle for {idle_for:.2f} seconds. Closing.")
            self.close_redshift_connection()
196
+
197
def _mark_activity(self):
    """Record that the connection was just used and re-arm the idle timer."""
    with self._lock:
        now = time.time()
        self._last_activity = now
        self._start_idle_timer()
201
+
202
def connect(self):
    """Open the Redshift connection unless a usable one already exists.

    When a connection is already open only its idle timer is reset. Otherwise a
    new autocommit connection and cursor are created from ``self._db_params``,
    the idle timer is started and, if a default bucket is configured, the
    default S3 connector is initialised.

    If any of those steps fails, the partially built state is rolled back (idle
    timer cancelled, connection closed, ``connection``/``cursor`` cleared) so a
    half-open connection is never leaked, and the failure is reported through
    ``log_and_raise_error``.
    """
    with self._lock:
        if self._is_connection_open():
            self._start_idle_timer()  # Reset timer on use
            return

        try:
            self.connection = redshift_connector.connect(**self._db_params)
            self.connection.autocommit = True
            self.cursor = self.connection.cursor()
            self._start_idle_timer()  # Start idle timer on new connection

            # Initialize the default S3 connector only when a default bucket is configured.
            if self._s3_bucket:
                self.s3 = self._get_s3_for_bucket(self._s3_bucket)

        except Exception as e:
            # Roll back whatever was set up before the failure. Without this a
            # connection opened just before the error would stay open with no
            # owner, and later calls would treat it as a healthy session.
            self._cancel_idle_timer()
            half_open = self.connection
            self.cursor = None
            self.connection = None
            if half_open is not None:
                try:
                    half_open.close()
                except Exception:
                    self._logger.debug("Failed to close half-open connection", exc_info=True)
            log_and_raise_error(self._logger, f"Failed to connect to Redshift: {e}")
221
+
222
def close_redshift_connection(self):
    """Close the database connection if it is open and always drop the references.

    The idle timer is cancelled first. ``close()`` is only attempted when the
    connection still reports as open, and errors raised by ``close()`` are
    ignored. ``self.connection`` and ``self.cursor`` are reset to ``None`` in
    every case, including when the connection was already closed or broken, so
    later liveness checks never look at a stale object.
    """
    # Acquire lock to synchronize with the idle timer closure
    with self._lock:
        self._cancel_idle_timer()
        conn = getattr(self, "connection", None)
        try:
            if conn is not None and self._is_connection_open():
                try:
                    conn.close()
                except Exception:
                    # Some drivers raise on close if already closed; not fatal.
                    self._logger.debug("Ignoring error raised while closing connection", exc_info=True)
        except Exception as e:
            self._logger.warning(f"Error closing connection: {e}")
        finally:
            # Clear unconditionally, not only after a successful close().
            self.cursor = None
            self.connection = None
249
+
250
def __del__(self):
    """Best-effort cleanup: try to close the connection when the instance is collected."""
    try:
        self.close_redshift_connection()
    except Exception:
        # Finalizers must not raise; the object may be only partially initialised.
        return
256
+
257
def __enter__(self):
    """Context-manager entry: make sure a connection exists and hand back this instance."""
    self.connect()
    return self
261
+
262
def __exit__(self, exc_type, exc_val, exc_tb):
    """Context-manager exit: close the connection; exceptions from the body are not suppressed."""
    self.close_redshift_connection()
265
+
266
def _ensure_connected(self):
    """Guarantee an open connection: reuse the live one or open a new one."""
    with self._lock:
        if self._is_connection_open():
            # Reusing the existing connection counts as activity.
            self._start_idle_timer()
        else:
            self.connect()
274
+
275
def _resolve_query(self, query: str, **kwargs) -> str:
    """Return SQL text for *query*, loading it from disk when it names a ``.sql`` file.

    A value whose stripped form ends with ``.sql`` is passed, together with
    ``kwargs``, to ``load_sql_query``; anything else (including ``None`` and the
    empty string) is returned unchanged.
    """
    names_sql_file = bool(query) and query.strip().endswith(".sql")
    if not names_sql_file:
        return query

    sql_text = load_sql_query(query.strip(), **kwargs)
    if sql_text is None:
        log_and_raise_error(self._logger, f"Could not load SQL file: {query}")
    self._logger.info(f"Loaded SQL from file: {query}")
    return sql_text
284
+
285
def execute_sql(self, query: str, fetch_result: bool = False, fetch_all: bool = False, **kwargs):
    """
    Execute a SQL query with automatic connection management and activity tracking.

    This method is useful for executing DDL, DML, or queries that don't return
    data (or when you need to fetch results). It automatically:
    - Ensures the connection is active before executing
    - Resets the idle timer to prevent connection timeout
    - Optionally fetches and returns results

    For queries that return datasets (SELECT), prefer using the sql() method
    which returns pandas or polars DataFrames.

    Parameters
    ----------
    query : str
        The SQL query to execute, or a path to a .sql file (relative to project root).
        If a .sql file path is provided, its contents are loaded automatically.
    fetch_result : bool, optional
        If True, fetches and returns a single row result using fetchone().
        Defaults to False.
    fetch_all : bool, optional
        If True, fetches and returns all rows using fetchall().
        Takes precedence over fetch_result if both are True.
        Defaults to False.
    **kwargs
        Template variables forwarded to the SQL file loader.

    Returns
    -------
    tuple, list of tuples, or None
        - If fetch_all=True: the value of cursor.fetchall()
        - If fetch_result=True: the value of cursor.fetchone()
        - Otherwise: None

    Notes
    -----
    The instance lock is held while the statement runs. The idle-timer callback
    takes the same lock, so a statement that runs longer than the idle timeout
    can no longer have its connection probed or closed underneath it; the
    callback waits, then sees the activity timestamp refreshed in ``finally``.

    Examples
    --------
    # Execute a DDL command
    dc.execute_sql("CREATE TABLE test (id INT, name VARCHAR(100))")

    # Execute from a SQL file
    dc.execute_sql("queries/create_table.sql")

    # Execute query and fetch single result
    result = dc.execute_sql("SELECT COUNT(*) FROM test", fetch_result=True)
    count = result[0]  # Get the count value

    # Execute query and fetch all results
    results = dc.execute_sql("SELECT * FROM test", fetch_all=True)
    for row in results:
        print(row)
    """
    query = self._resolve_query(query, **kwargs)
    # Re-entrant lock: nested calls from the same thread (e.g. from sql() or
    # create_spectrum_table()) still proceed, while the timer thread is kept out.
    with self._lock:
        self._ensure_connected()
        self._mark_activity()
        try:
            self.cursor.execute(query)
            if fetch_all:
                return self.cursor.fetchall()
            if fetch_result:
                return self.cursor.fetchone()
            return None
        finally:
            # Refresh the activity timestamp before the lock is released.
            self._start_idle_timer()
349
+
350
def sql(self, query: str = None, format: str = "pandas", **kwargs) -> pl.DataFrame | pd.DataFrame:
    """
    Execute a SQL query against the Redshift database and return the result.

    Parameters
    ----------
    query : str
        The SQL query to execute, or a path to a .sql file (relative to project root).
    format : str
        Output format: 'pandas' or 'polars'. Defaults to 'pandas'.
    **kwargs
        Template variables forwarded to the SQL file loader.

    Returns
    -------
    pandas.DataFrame or polars.DataFrame
        For 'pandas' the result of ``cursor.fetch_dataframe()`` after the query was
        executed; for 'polars' the result of ``pl.read_database`` on the cursor.

    Notes
    -----
    Any failure, including an unsupported ``format``, is reported through
    ``log_and_raise_error`` with an "Error fetching data" message.

    The instance lock is held for the whole execute-and-fetch sequence. The
    idle-timer callback takes the same lock, so it cannot probe or close the
    connection between executing the statement and reading its rows, nor during
    a fetch that outlasts the idle timeout.
    """
    query = self._resolve_query(query, **kwargs)
    with self._lock:
        self._ensure_connected()
        self._mark_activity()
        try:
            if format not in ["pandas", "polars"]:
                log_and_raise_error(self._logger, "Invalid format. Use 'pandas' or 'polars'.")

            if format == "pandas":
                self.execute_sql(query)
                tmp = self.cursor.fetch_dataframe()
                self._logger.info("Data fetched successfully")
                return tmp
            elif format == "polars":
                tmp = pl.read_database(query, connection=self.cursor)
                self._logger.info("Data fetched successfully")
                return tmp
        except Exception as e:
            log_and_raise_error(self._logger, f"Error fetching data: {e}")
        finally:
            # Refresh the activity timestamp before the lock is released.
            self._start_idle_timer()
383
+
384
def _get_s3_for_bucket(self, bucket: str = None) -> S3Connector:
    """Return the S3Connector bound to *bucket*, creating and caching it on first use.

    ``bucket=None`` selects the instance's default bucket. Keeping one connector
    per bucket means no shared connector is ever re-pointed at another bucket.
    When the resolved name equals the configured default bucket, the new
    connector is also published as ``self.s3`` for backward compatibility.
    """
    bucket_name = self._resolve_s3_bucket(bucket)

    try:
        return self._s3_connectors[bucket_name]
    except KeyError:
        pass  # First request for this bucket: build the connector below.

    # Quieten the per-bucket logger before the connector is created.
    logging.getLogger(f"S3 Connector ({bucket_name})").setLevel(logging.WARNING)
    new_connector = S3Connector(bucket=bucket_name, log_level="WARNING")
    self._s3_connectors[bucket_name] = new_connector

    if bucket_name == self._s3_bucket:
        self.s3 = new_connector

    return new_connector
410
+
411
def _resolve_s3_bucket(self, bucket: str = None) -> str:
    """Pick the bucket to use (argument first, then the instance default) without edge slashes.

    When neither is set the problem is reported through ``log_and_raise_error``.
    """
    chosen = self._s3_bucket if not bucket else bucket
    if not chosen:
        log_and_raise_error(
            self._logger,
            "No S3 bucket configured. Pass s3_bucket=... or set ML_ANALYTICS_S3_BUCKET.",
        )
    return chosen.strip("/")
419
+
420
+ def create_spectrum_table(
421
+ self,
422
+ table: str,
423
+ schema: str,
424
+ relative_path: str,
425
+ partitions: list[tuple[str, str]] = None,
426
+ force_table_creation: bool = False,
427
+ sync_partitions_on_creation: bool = True,
428
+ return_query: bool = False,
429
+ s3_bucket: str = None,
430
+ ) -> str | None:
431
+ self._ensure_connected()
432
+ self._mark_activity()
433
+
434
+ # Get the appropriate S3 connector for this operation
435
+ working_s3 = self._get_s3_for_bucket(s3_bucket)
436
+ working_bucket = self._resolve_s3_bucket(s3_bucket)
437
+
438
+ if relative_path:
439
+ relative_path = relative_path.strip().lstrip("/")
440
+
441
+ s3_path = working_s3.get_path(relative_path=relative_path)
442
+
443
+ # Check if files exist in S3 before proceeding
444
+ try:
445
+ files_in_path = working_s3.list_files(prefix=relative_path, bucket=working_bucket)
446
+ if not files_in_path:
447
+ log_and_raise_error(
448
+ self._logger,
449
+ f"No files found at S3 path: {s3_path}. Please verify the path exists and contains data files.",
450
+ FileNotFoundError,
451
+ )
452
+ self._logger.info(f"Found {len(files_in_path)} file(s) in {s3_path}")
453
+ except FileNotFoundError:
454
+ raise
455
+ except Exception as e:
456
+ self._logger.warning(f"Could not verify files in S3 path {s3_path}: {e}")
457
+
458
+ try:
459
+ full_table_name = f"{schema}.{table}"
460
+ table_exists = self._spectrum_table_exists(table_name=table, schema_name=schema)
461
+ create_table_query = None
462
+ partition_search_depth = len(partitions) if partitions else 2
463
+
464
+ if table_exists:
465
+ if force_table_creation:
466
+ self._logger.info(
467
+ f"Table {full_table_name} exists and force_table_creation is True. Dropping table."
468
+ )
469
+ drop_query = f"DROP TABLE IF EXISTS {full_table_name};"
470
+ try:
471
+ self.execute_sql(drop_query)
472
+ except Exception as e:
473
+ log_and_raise_error(self._logger, f"Error dropping table {full_table_name}: {e}")
474
+ else:
475
+ self._logger.info(
476
+ f"Table {full_table_name} already exists and force_table_creation is False. Skipping creation."
477
+ )
478
+
479
+ if self._spectrum_table_exists(table_name=table, schema_name=schema) is False:
480
+ try:
481
+ if not relative_path.endswith("/"):
482
+ relative_path_for_listing = relative_path + "/"
483
+ else:
484
+ relative_path_for_listing = relative_path
485
+
486
+ original_relative_path = relative_path
487
+
488
+ parquet_file_key = None
489
+ candidates = [
490
+ relative_path_for_listing,
491
+ relative_path_for_listing.lstrip("/"),
492
+ f"{working_bucket}/{relative_path_for_listing}".lstrip("/"),
493
+ original_relative_path,
494
+ original_relative_path.rstrip("/"),
495
+ ]
496
+
497
+ for prefix_candidate in candidates:
498
+ try:
499
+ files_prefix = working_s3.list_files_in_prefix(prefix=prefix_candidate) or []
500
+ except Exception:
501
+ files_prefix = []
502
+ for file_key in files_prefix:
503
+ if file_key.lower().endswith(".parquet"):
504
+ parquet_file_key = file_key
505
+ self._logger.debug(
506
+ "Parquet discovered (prefix-list): %s (tried %s)",
507
+ parquet_file_key,
508
+ prefix_candidate,
509
+ )
510
+ break
511
+ if parquet_file_key:
512
+ break
513
+
514
+ if parquet_file_key is None:
515
+ for prefix_candidate in candidates:
516
+ try:
517
+ potential_partition_dirs = (
518
+ working_s3.list_partition_paths(
519
+ prefix=prefix_candidate, depth=partition_search_depth
520
+ )
521
+ or []
522
+ )
523
+ except Exception:
524
+ potential_partition_dirs = []
525
+ for partition_dir in potential_partition_dirs:
526
+ try:
527
+ files_in_partition = working_s3.list_files_in_prefix(prefix=partition_dir) or []
528
+ except Exception:
529
+ files_in_partition = []
530
+ for file_key_in_part in files_in_partition:
531
+ if file_key_in_part.lower().endswith(".parquet"):
532
+ parquet_file_key = file_key_in_part
533
+ self._logger.debug(
534
+ "Parquet discovered in partition: %s (tried %s)",
535
+ parquet_file_key,
536
+ partition_dir,
537
+ )
538
+ break
539
+ if parquet_file_key:
540
+ break
541
+ if parquet_file_key:
542
+ break
543
+
544
+ if parquet_file_key is None:
545
+ for prefix_candidate in candidates:
546
+ try:
547
+ files_candidate = working_s3.list_files(prefix=prefix_candidate) or []
548
+ except Exception:
549
+ files_candidate = []
550
+ for file_key in files_candidate:
551
+ if file_key.lower().endswith(".parquet"):
552
+ parquet_file_key = file_key
553
+ self._logger.debug(
554
+ "Parquet discovered recursively: %s (tried %s)",
555
+ parquet_file_key,
556
+ prefix_candidate,
557
+ )
558
+ break
559
+ if parquet_file_key:
560
+ break
561
+
562
+ if not parquet_file_key:
563
+ log_and_raise_error(
564
+ self._logger,
565
+ f"No .parquet file found in '{relative_path}' or its immediate subdirectories "
566
+ "for schema inference.",
567
+ FileNotFoundError,
568
+ )
569
+
570
+ parquet_key_for_path = parquet_file_key
571
+ if parquet_key_for_path.startswith(f"{working_bucket}/"):
572
+ parquet_key_for_path = parquet_key_for_path[len(working_bucket) + 1 :]
573
+ parquet_s3_path = working_s3.get_path(parquet_key_for_path)
574
+ df_schema = pq.read_schema(parquet_s3_path)
575
+ columns_with_types = self._convert_pyarrow_schema_to_sql(df_schema, partitions)
576
+ column_names = [col for col, _ in columns_with_types]
577
+ self._check_redshift_reserved_words(column_names)
578
+
579
+ except Exception as e:
580
+ log_and_raise_error(self._logger, f"Failed to infer schema or find Parquet file: {e}")
581
+
582
+ create_table_query = f"CREATE EXTERNAL TABLE {full_table_name} (\n"
583
+ create_table_query += ",\n".join([f" {col} {dtype}" for col, dtype in columns_with_types])
584
+ create_table_query += "\n)\n"
585
+
586
+ if partitions:
587
+ partition_cols_str = ", ".join([f"{p_name} {p_type}" for p_name, p_type in partitions])
588
+ create_table_query += f"PARTITIONED BY ({partition_cols_str})\n"
589
+
590
+ create_table_query += "STORED AS PARQUET\n"
591
+ create_table_query += f"LOCATION '{s3_path}';"
592
+
593
+ try:
594
+ self.execute_sql(create_table_query)
595
+ self._logger.info(f"Successfully created Spectrum table: {full_table_name}")
596
+ except Exception as e:
597
+ log_and_raise_error(self._logger, f"Error creating Spectrum table {full_table_name}: {e}")
598
+
599
+ if self._spectrum_table_exists(table_name=table, schema_name=schema):
600
+ if partitions and sync_partitions_on_creation:
601
+ partition_column_names = [p_name for p_name, p_type in partitions]
602
+ self.sync_spectrum_partitions(
603
+ table=table,
604
+ schema=schema,
605
+ relative_path=relative_path,
606
+ partitions_columns=partition_column_names,
607
+ s3_bucket=working_bucket,
608
+ )
609
+
610
+ if create_table_query and return_query:
611
+ return create_table_query
612
+ finally:
613
+ self._start_idle_timer()
614
+
615
def _spectrum_table_exists(self, table_name: str, schema_name: str) -> bool:
    """Return True when ``svv_external_tables`` lists *table_name* in *schema_name*.

    Both names are embedded as SQL string literals, so single quotes in them are
    doubled to keep the statement well-formed. The result is always a real
    ``bool``; a missing result row is reported as ``False``.
    """
    self._ensure_connected()
    self._mark_activity()
    try:
        schema_literal = str(schema_name).replace("'", "''")
        table_literal = str(table_name).replace("'", "''")
        check_sql = f"""
            SELECT EXISTS (
                SELECT 1
                FROM svv_external_tables
                WHERE schemaname = '{schema_literal}'
                AND tablename = '{table_literal}'
            )
        """
        row = self.execute_sql(check_sql, fetch_result=True)
        # Coerce the driver value so callers can rely on plain boolean semantics.
        return bool(row[0]) if row else False
    finally:
        self._start_idle_timer()
632
+
633
def sync_spectrum_data(
    self,
    table: str,
    schema: str,
    relative_path: str,
    partition_values: dict[str, str],
    s3_bucket: str = None,
):
    """
    Register one partition of a Spectrum table after checking that its data exists in S3.

    Parameters
    ----------
    table, schema : str
        External table and schema receiving the partition.
    relative_path : str
        Base folder of the table inside the bucket; ``None`` is treated as "".
    partition_values : dict[str, str]
        Partition column -> value, in directory order. They form both the
        ``col=value/`` path suffix and the ``PARTITION(...)`` specification.
    s3_bucket : str, optional
        Bucket to use instead of the instance default.

    Notes
    -----
    Runs ``ALTER TABLE ... ADD IF NOT EXISTS PARTITION``. Single quotes inside
    partition values and in the location are doubled so they stay valid SQL
    string literals. Every failure, including "no files for this partition",
    is reported through ``log_and_raise_error``.
    """
    # Built before the try block: the error handler below formats this name, so
    # it must exist even when resolving the S3 connector is what fails.
    fully_qualified_table = f"{schema}.{table}"

    def sql_literal(value):
        return str(value).replace("'", "''")

    self._ensure_connected()
    self._mark_activity()
    try:
        # Get the appropriate S3 connector for this operation
        working_s3 = self._get_s3_for_bucket(s3_bucket)
        working_bucket = self._resolve_s3_bucket(s3_bucket)

        # Normalize the base relative path and build the partition suffix
        if relative_path is None:
            relative_path = ""
        relative_path_for_listing = relative_path if relative_path.endswith("/") else relative_path + "/"
        partition_suffix = "".join(f"{col}={val}/" for col, val in partition_values.items())

        # The listing helper may expect keys with or without a leading slash or
        # bucket prefix; the first spelling that returns files wins.
        candidates = [
            relative_path_for_listing + partition_suffix,
            relative_path_for_listing.lstrip("/") + partition_suffix,
            f"{working_bucket}/{relative_path_for_listing}{partition_suffix}".lstrip("/"),
        ]

        found_files = []
        for candidate in candidates:
            try:
                files = working_s3.list_files(prefix=candidate) or []
            except Exception:
                files = []
            if files:
                found_files = files
                self._logger.debug("Found %d files for partition candidate '%s'", len(files), candidate)
                break

        if not found_files:
            # No files found for the requested partition; raise a clear error
            log_and_raise_error(
                self._logger,
                f"No files found in S3 for partition values {partition_values} under '{relative_path}'.",
                FileNotFoundError,
            )

        full_s3_location_for_partition = working_s3.get_path(relative_path)
        if not full_s3_location_for_partition.endswith("/"):
            full_s3_location_for_partition += "/"
        full_s3_location_for_partition += partition_suffix

        partition_spec = ", ".join(f"{col}='{sql_literal(val)}'" for col, val in partition_values.items())
        query = f"""
            ALTER TABLE {fully_qualified_table}
            ADD IF NOT EXISTS PARTITION({partition_spec})
            LOCATION '{sql_literal(full_s3_location_for_partition)}'
        """
        self.execute_sql(query)
        self._logger.info(f"Successful sync for table {fully_qualified_table}.")
    except Exception as e:
        log_and_raise_error(
            self._logger, f"Error syncing partition {partition_values} for table {fully_qualified_table}: {e}"
        )
    finally:
        self._start_idle_timer()
706
+
707
def sync_spectrum_partitions(
    self, table: str, schema: str, relative_path: str, partitions_columns: list[str], s3_bucket: str = None
):
    """
    Discover partition folders in S3 and register each one on the Spectrum table.

    Folders are listed under ``relative_path`` down to ``len(partitions_columns)``
    levels. A folder is used only if every level reads ``<column>=<value>`` with
    the columns in the given order; anything else is skipped. One
    ``ALTER TABLE ... ADD IF NOT EXISTS PARTITION`` statement is run per folder.

    An "already exists" error from the database is logged and ignored; any other
    error is logged and stops the remaining statements without raising.
    """
    self._ensure_connected()
    self._mark_activity()
    try:
        # Get the appropriate S3 connector for this operation
        s3 = self._get_s3_for_bucket(s3_bucket)

        target_table = f"{schema}.{table}"
        base_prefix = relative_path if relative_path.endswith("/") else relative_path + "/"

        partition_paths = s3.list_partition_paths(base_prefix, len(partitions_columns))
        if not partition_paths:
            self._logger.info(
                "No partition paths found under s3://%s/%s matching depth %d.",
                s3.bucket,
                base_prefix,
                len(partitions_columns),
            )
            return

        statements = []
        for partition_path in partition_paths:
            if not partition_path.startswith(base_prefix):
                self._logger.warning(
                    "Skipping path %s, does not start with base prefix %s.",
                    partition_path,
                    base_prefix,
                )
                continue

            segments = partition_path[len(base_prefix) :].strip("/").split("/")
            if len(segments) != len(partitions_columns):
                self._logger.warning(
                    "Skipping path %s, parts count (%d) != partition columns count (%d).",
                    partition_path,
                    len(segments),
                    len(partitions_columns),
                )
                continue

            # Each directory level must be "<expected column>=<value>".
            spec_items = []
            for column, segment in zip(partitions_columns, segments):
                name, separator, raw_value = segment.partition("=")
                if not separator or name != column:
                    break
                escaped_value = raw_value.replace("'", "''")
                spec_items.append(f"{column}='{escaped_value}'")
            if len(spec_items) != len(partitions_columns):
                continue  # A level did not match its column; skip silently.

            partition_spec = ", ".join(spec_items)
            location = s3.get_path(partition_path)
            if not location.endswith("/"):
                location += "/"
            statement = f"""
                ALTER TABLE {target_table}
                ADD IF NOT EXISTS PARTITION({partition_spec})
                LOCATION '{location}';
            """
            statements.append((statement, partition_spec))

        if not statements:
            self._logger.info(f"No new partitions to add for {target_table}.")
            return

        for statement, spec_for_log in statements:
            try:
                self.execute_sql(statement)
            except Exception as e:
                if "already exists" not in str(e).lower():
                    self._logger.error(
                        f"Error adding partition for {target_table} with spec ({spec_for_log}): {e}"
                    )
                    break
                self._logger.info(f"Partition ({spec_for_log}) likely already exists for {target_table}.")
        else:
            # Reached only when the loop was not stopped by an error.
            self._logger.info(f"Finished syncing partitions for {target_table}.")
    finally:
        self._start_idle_timer()
798
+
799
def copy_table(self, source_table: str, destination_table: str, drop_destination_table: bool = True):
    """
    Copy all rows of one table into another.

    Parameters
    ----------
    source_table, destination_table : str
        Schema-qualified names (``schema.table``). Surrounding whitespace is
        ignored. The two names must differ.
    drop_destination_table : bool, optional
        When True (default) the destination is dropped if it exists and rebuilt
        with ``CREATE TABLE ... AS SELECT``. When False the rows are appended
        with ``INSERT INTO ... SELECT`` after a ``LIMIT 0`` select on the source.

    Notes
    -----
    Names are stripped *before* validation. Otherwise two names that differ only
    by surrounding whitespace would pass the "must differ" check and the source
    table would then be dropped as the "destination". Validation and database
    errors are reported through ``log_and_raise_error``.
    """
    self._ensure_connected()
    self._mark_activity()
    try:
        # Normalise first so every check below sees the names that will be used.
        if source_table:
            source_table = source_table.strip()
        if destination_table:
            destination_table = destination_table.strip()

        if not source_table or not destination_table:
            log_and_raise_error(self._logger, "Source and target table names must be provided.")
        if "." not in source_table or "." not in destination_table:
            log_and_raise_error(
                self._logger, "Source and target table names must include schema (e.g., 'schema.table')."
            )
        if source_table == destination_table:
            log_and_raise_error(self._logger, "Source and target table names cannot be the same.")

        if drop_destination_table:
            self.execute_sql(f"DROP TABLE IF EXISTS {destination_table};")
            create_table_query = f"""CREATE TABLE {destination_table} AS SELECT * FROM {source_table}"""
            self.execute_sql(create_table_query)
        else:
            # Zero-row select on the source runs before any rows are inserted.
            self.execute_sql(f"SELECT * FROM {source_table} LIMIT 0;")
            insert_query = f"INSERT INTO {destination_table} SELECT * FROM {source_table};"
            self.execute_sql(insert_query)
        self._logger.info(f"Data copied successfully from {source_table} to {destination_table}.")
    except Exception as e:
        log_and_raise_error(self._logger, f"Error during table copy: {e}")
    finally:
        self._start_idle_timer()
828
+
829
+ def unload_to_s3(
830
+ self,
831
+ query: str,
832
+ relative_path: str,
833
+ file_prefix: str = "data",
834
+ s3_bucket: str = None,
835
+ parallel: bool = True,
836
+ overwrite: bool = True,
837
+ drop_existing_files: bool = False,
838
+ format: str = "PARQUET",
839
+ max_file_size: str = None,
840
+ partition_by: list[str] = None,
841
+ ):
842
+ """
843
+ Execute a Redshift UNLOAD command to export query results directly to S3.
844
+
845
+ Parameters
846
+ ----------
847
+ query : str
848
+ The SELECT query to unload. Accepts:
849
+ - A SELECT/WITH query string.
850
+ - A table name (will be wrapped in ``SELECT * FROM``).
851
+ - A path to a ``.sql`` file (relative to project root).
852
+ - A multi-statement SQL string separated by semicolons.
853
+
854
+ When multiple statements are provided (via file or inline string),
855
+ all preceding statements are executed first (e.g., CREATE TEMP TABLE,
856
+ INSERT, etc.) and the **last** statement is used as the UNLOAD query.
857
+ relative_path : str
858
+ The relative path within the S3 bucket where files will be saved (e.g., 'my-data/output/').
859
+ file_prefix : str, optional
860
+ Prefix for the output files. Defaults to 'data'.
861
+ s3_bucket : str, optional
862
+ The S3 bucket name. Defaults to the instance's bucket.
863
+ parallel : bool, optional
864
+ If True, uses PARALLEL ON for faster unload with multiple files (one per slice).
865
+ If False, uses PARALLEL OFF for a single output file. Defaults to True.
866
+ Note: When using PARTITION BY, Redshift may create multiple files per partition
867
+ even with PARALLEL OFF, as it parallelizes within each partition.
868
+ overwrite : bool, optional
869
+ If True, adds ALLOWOVERWRITE to replace existing files. Defaults to True.
870
+ drop_existing_files : bool, optional
871
+ If True, deletes all existing files matching the prefix before UNLOAD.
872
+ This ensures a clean output directory. Defaults to False.
873
+ Note: This happens before UNLOAD, regardless of the overwrite setting.
874
+ format : str, optional
875
+ Output format: 'PARQUET', 'CSV', or 'JSON'. Defaults to 'PARQUET'.
876
+ max_file_size : str, optional
877
+ Maximum size per file (e.g., '100 MB', '1 GB'). Only valid with PARALLEL ON.
878
+ Causes Redshift to split files larger than this size. Use this to control file sizes
879
+ when you have large datasets.
880
+ partition_by : list[str], optional
881
+ List of column names to partition by (Parquet only). Example: ['year', 'month'].
882
+
883
+ Returns
884
+ -------
885
+ None
886
+
887
+ Examples
888
+ --------
889
+ # Simple unload (multiple files, one per cluster slice)
890
+ dc.unload_to_s3(
891
+ query="SELECT * FROM my_schema.my_table WHERE date >= '2024-01-01'",
892
+ relative_path="exports/my_table/",
893
+ file_prefix="my_table_2024"
894
+ )
895
+
896
+ # Single file output
897
+ dc.unload_to_s3(
898
+ query="my_schema.summary_table",
899
+ relative_path="exports/summary/",
900
+ file_prefix="summary",
901
+ parallel=False
902
+ )
903
+
904
+ # Control file size (will create more files if data exceeds max_file_size)
905
+ dc.unload_to_s3(
906
+ query="SELECT * FROM my_schema.large_table",
907
+ relative_path="exports/large_table/",
908
+ file_prefix="large",
909
+ max_file_size="500 MB"
910
+ )
911
+
912
+ # Unload with partitioning
913
+ dc.unload_to_s3(
914
+ query="SELECT * FROM my_schema.events",
915
+ relative_path="exports/events/",
916
+ file_prefix="events",
917
+ partition_by=["year", "month"]
918
+ )
919
+
920
+ # Unload from a .sql file (last SELECT is used for UNLOAD)
921
+ dc.unload_to_s3(
922
+ query="sql/my_export_query.sql",
923
+ relative_path="exports/my_table/",
924
+ file_prefix="my_table"
925
+ )
926
+
927
+ # Multi-statement SQL (preceding statements run first, last used for UNLOAD)
928
+ dc.unload_to_s3(
929
+ query=\"\"\"
930
+ CREATE TEMP TABLE tmp AS SELECT id, name FROM users WHERE active;
931
+ SELECT * FROM tmp
932
+ \"\"\",
933
+ relative_path="exports/active_users/",
934
+ file_prefix="active_users"
935
+ )
936
+
937
+ Notes
938
+ -----
939
+ - With PARALLEL ON: Creates one file per cluster slice (typically 2-32 files).
940
+ - With PARALLEL OFF: Creates a single file, UNLESS using PARTITION BY.
941
+ - With PARTITION BY: Redshift parallelizes within each partition, creating multiple
942
+ files per partition regardless of the PARALLEL setting. This is a Redshift limitation.
943
+ - To control file sizes with large datasets, use the max_file_size parameter.
944
+ """
945
+
946
+ self._ensure_connected()
947
+ self._mark_activity()
948
+
949
+ try:
950
+ # Get the appropriate S3 connector for this operation
951
+ working_bucket = self._resolve_s3_bucket(s3_bucket)
952
+ working_s3 = self._get_s3_for_bucket(working_bucket)
953
+
954
+ if relative_path:
955
+ relative_path = relative_path.strip().lstrip("/")
956
+ # Remove double slashes
957
+ relative_path = re.sub(r"/+", "/", relative_path)
958
+ if not relative_path.endswith("/"):
959
+ relative_path += "/"
960
+ else:
961
+ relative_path = ""
962
+
963
+ s3_path = f"s3://{working_bucket}/{relative_path}{file_prefix}"
964
+
965
+ format = format.upper()
966
+ if format not in ["PARQUET", "CSV", "JSON"]:
967
+ log_and_raise_error(self._logger, f"Invalid format '{format}'. Must be PARQUET, CSV, or JSON.")
968
+
969
+ if partition_by and format != "PARQUET":
970
+ log_and_raise_error(self._logger, "partition_by parameter is only supported for PARQUET format.")
971
+
972
+ try:
973
+ existing_files = working_s3.list_files(prefix=f"{relative_path}{file_prefix}", bucket=working_bucket)
974
+ if existing_files:
975
+ if drop_existing_files:
976
+ self._logger.info(
977
+ f"Found {len(existing_files)} existing file(s) at {s3_path}. "
978
+ "Deleting them (drop_existing_files=True)..."
979
+ )
980
+ for file_key in existing_files:
981
+ try:
982
+ working_s3.delete_file(file_key, bucket=working_bucket)
983
+ except Exception as delete_error:
984
+ self._logger.warning(f"Failed to delete {file_key}: {delete_error}")
985
+ self._logger.info(f"Deleted {len(existing_files)} file(s).")
986
+ elif overwrite:
987
+ self._logger.warning(
988
+ f"Found {len(existing_files)} existing file(s) at {s3_path}. "
989
+ "They will be overwritten (overwrite=True)."
990
+ )
991
+ else:
992
+ self._logger.warning(
993
+ f"Found {len(existing_files)} existing file(s) at {s3_path}. "
994
+ "UNLOAD will fail unless overwrite=True is set."
995
+ )
996
+ except Exception as e:
997
+ self._logger.debug(f"Could not check for existing files: {e}")
998
+
999
+ session = boto3.Session()
1000
+ credentials = session.get_credentials()
1001
+ if credentials is None:
1002
+ log_and_raise_error(
1003
+ self._logger,
1004
+ "Unable to retrieve AWS credentials. Ensure boto3 is configured correctly.",
1005
+ )
1006
+
1007
+ aws_access_key = credentials.access_key
1008
+ aws_secret_key = credentials.secret_key
1009
+ aws_session_token = credentials.token
1010
+
1011
+ # Clean and validate query
1012
+ if query is None:
1013
+ log_and_raise_error(
1014
+ self._logger,
1015
+ "Query cannot be None. Please provide a valid SQL query string or ensure the SQL file exists.",
1016
+ )
1017
+ query = self._resolve_query(query)
1018
+
1019
+ # Handle multi-statement SQL: split, execute preceding statements,
1020
+ # keep the last one as the UNLOAD query.
1021
+ statements = _split_sql_statements(query)
1022
+ if len(statements) > 1:
1023
+ preceding = statements[:-1]
1024
+ query = statements[-1]
1025
+ self._logger.info(
1026
+ f"Multi-statement SQL detected: executing {len(preceding)} preceding statement(s) before UNLOAD"
1027
+ )
1028
+ for i, stmt in enumerate(preceding):
1029
+ self._logger.debug(f"Executing preceding statement {i + 1}: {stmt[:100]}...")
1030
+ self.execute_sql(stmt)
1031
+ elif len(statements) == 1:
1032
+ query = statements[0]
1033
+
1034
+ # Check if it's a query statement (SELECT, WITH, etc.) or a table name
1035
+ # Skip leading comments to check actual SQL statement
1036
+ lines = query.split("\n")
1037
+ first_sql_line = None
1038
+ for line in lines:
1039
+ stripped = line.strip()
1040
+ # Skip empty lines and comments
1041
+ if stripped and not stripped.startswith("--"):
1042
+ first_sql_line = stripped.upper()
1043
+ break
1044
+
1045
+ # If no SQL found or doesn't look like a query, treat as table name
1046
+ if first_sql_line is None or not (
1047
+ first_sql_line.startswith("SELECT")
1048
+ or first_sql_line.startswith("WITH")
1049
+ or first_sql_line.startswith("(")
1050
+ ):
1051
+ query = f"SELECT * FROM {query}"
1052
+
1053
+ # Build UNLOAD query
1054
+ unload_query = f"""
1055
+ UNLOAD ($$
1056
+ {query}
1057
+ $$)
1058
+ TO '{s3_path}'
1059
+ ACCESS_KEY_ID '{aws_access_key}'
1060
+ SECRET_ACCESS_KEY '{aws_secret_key}'
1061
+ """
1062
+
1063
+ if aws_session_token:
1064
+ unload_query += f"SESSION_TOKEN '{aws_session_token}'\n"
1065
+
1066
+ # Add format
1067
+ if format == "PARQUET":
1068
+ unload_query += "FORMAT AS PARQUET\n"
1069
+ elif format == "CSV":
1070
+ unload_query += "FORMAT AS CSV\n"
1071
+ elif format == "JSON":
1072
+ unload_query += "FORMAT AS JSON\n"
1073
+
1074
+ if partition_by:
1075
+ partition_cols = ", ".join(partition_by)
1076
+ unload_query += f"PARTITION BY ({partition_cols})\n"
1077
+ if not parallel:
1078
+ self._logger.warning(
1079
+ "Using PARTITION BY with PARALLEL OFF may still create multiple files per partition. "
1080
+ "Redshift parallelizes within each partition even when PARALLEL OFF is specified."
1081
+ )
1082
+
1083
+ if parallel:
1084
+ unload_query += "PARALLEL ON\n"
1085
+ else:
1086
+ unload_query += "PARALLEL OFF\n"
1087
+
1088
+ if max_file_size and parallel:
1089
+ unload_query += f"MAXFILESIZE {max_file_size}\n"
1090
+ elif max_file_size and not parallel:
1091
+ self._logger.warning("MAXFILESIZE is ignored when PARALLEL is OFF")
1092
+
1093
+ if overwrite:
1094
+ unload_query += "ALLOWOVERWRITE\n"
1095
+
1096
+ self._logger.info(f"Starting UNLOAD to {s3_path}")
1097
+ self._logger.debug(f"UNLOAD query: {unload_query}")
1098
+
1099
+ with self._lock:
1100
+ self._cancel_idle_timer()
1101
+ self._last_activity = time.time()
1102
+ self._ensure_connected()
1103
+ self.cursor.execute(unload_query)
1104
+ self._mark_activity()
1105
+
1106
+ self._logger.info(f"Successfully unloaded data to {s3_path}")
1107
+
1108
+ except Exception as e:
1109
+ log_and_raise_error(self._logger, f"Error during UNLOAD to S3: {e}")
1110
+ finally:
1111
+ self._start_idle_timer()
1112
+
1113
def load_from_s3(
    self,
    table: str,
    schema: str,
    relative_path: str,
    s3_bucket: str = None,
    format: str = "PARQUET",
    truncate_before_load: bool = False,
    drop_existing_table: bool = False,
    column_list: list[str] = None,
    column_types: dict[str, str] = None,
    ignore_header: int = None,
    delimiter: str = None,
    date_format: str = "auto",
    time_format: str = "auto",
    blank_as_null: bool = True,
    empty_as_null: bool = True,
    null_as: str = None,
    accept_invalid_chars: bool = False,
    max_error: int = 0,
    stat_update: bool = True,
    compupdate: bool = True,
    identity_column: str | bool = None,
):
    """
    Load data from S3 files into a Redshift table using the COPY command.

    Before the COPY runs, the target table is created when it is missing
    (PARQUET only), dropped and recreated when ``drop_existing_table`` is set
    (PARQUET only), or truncated when ``truncate_before_load`` is set.

    Parameters
    ----------
    table : str
        The name of the target table (without schema).
    schema : str
        The schema name where the table exists or will be created.
    relative_path : str
        The relative path within the S3 bucket where files are located (e.g., 'my-data/input/').
        Leading slashes are stripped and repeated slashes are collapsed. Must not be empty.
    s3_bucket : str, optional
        The S3 bucket name. Defaults to the instance's bucket.
    format : str, optional
        Input format: 'PARQUET', 'CSV', 'JSON', 'AVRO', or 'ORC' (case-insensitive).
        Defaults to 'PARQUET'.
    truncate_before_load : bool, optional
        If True, truncates the table before loading. Ignored when drop_existing_table
        is True. Defaults to False.
    drop_existing_table : bool, optional
        If True, drops the existing table and recreates it from the source schema.
        Only works with PARQUET format (schema is taken from column_types or inferred
        from the Parquet files). For other formats an error is raised *before* anything
        is dropped, so the existing table is left untouched.
        Defaults to False. Note: This is more destructive than truncate_before_load.
    column_list : list[str], optional
        List of column names in the target table to load data into.
        Only applied for non-PARQUET formats; for PARQUET the file columns must match
        the table columns, so this argument is ignored.
    column_types : dict[str, str], optional
        Column name -> SQL type mapping used when the table is auto-created.
        When provided, the table is created from exactly these columns (in dict order);
        otherwise the schema is inferred from a Parquet file.
        Example: {'created_date': 'DATE', 'status': 'VARCHAR(50)', 'amount': 'DECIMAL(10,2)'}
        Note: Only applies when the table is auto-created for PARQUET format.
    identity_column : str | bool, optional
        Add an auto-incrementing IDENTITY column when the table is auto-created.
        - If True: Creates a column named 'id' as BIGINT IDENTITY(1,1)
        - If str: Creates a column with the specified name as BIGINT IDENTITY(1,1)
        - If None/False: No identity column (default)
        Example: identity_column=True or identity_column='row_id'
    ignore_header : int, optional
        Number of header lines to ignore (CSV only). Example: 1 for single header row.
    delimiter : str, optional
        Field delimiter for CSV files. No DELIMITER clause is emitted when not specified.
    date_format : str, optional
        Date format string. Defaults to 'auto'. Examples: 'YYYY-MM-DD', 'auto'.
        Note: Only emitted for CSV and JSON formats.
    time_format : str, optional
        Timestamp format string. Defaults to 'auto'. Examples: 'YYYY-MM-DD HH:MI:SS', 'auto'.
        Note: Only emitted for CSV and JSON formats.
    blank_as_null : bool, optional
        If True, adds BLANKSASNULL. Defaults to True. Note: CSV only.
    empty_as_null : bool, optional
        If True, adds EMPTYASNULL. Defaults to True. Note: CSV only.
    null_as : str, optional
        String to interpret as NULL (e.g., 'NULL', '\\N'). Defaults to None. Note: CSV only.
    accept_invalid_chars : bool, optional
        If True, adds ACCEPTINVCHARS. Defaults to False. Note: CSV and JSON only.
    max_error : int, optional
        Maximum number of errors allowed before failing. Defaults to 0 (no MAXERROR clause).
    stat_update : bool, optional
        If True, emits STATUPDATE ON, otherwise STATUPDATE OFF. Defaults to True.
    compupdate : bool, optional
        If True, emits COMPUPDATE ON, otherwise COMPUPDATE OFF. Defaults to True.
        Note: Not emitted for PARQUET.

    Returns
    -------
    str
        The fully qualified table name (``schema.table``) that was loaded.

    Raises
    ------
    Errors are reported through ``log_and_raise_error``; any failure after the
    S3 pre-check is re-reported with the prefix "Error during COPY from S3".

    Examples
    --------
    # Simple Parquet load
    dc.load_from_s3(
        table="my_table",
        schema="my_schema",
        relative_path="imports/my_table/"
    )

    # CSV load with options
    dc.load_from_s3(
        table="my_table",
        schema="my_schema",
        relative_path="imports/my_table.csv",
        format="CSV",
        ignore_header=1,
        delimiter="|",
        truncate_before_load=True
    )

    # Auto-create table from Parquet if missing
    dc.load_from_s3(
        table="new_table",
        schema="my_schema",
        relative_path="imports/data/"
    )

    # Load specific columns (non-PARQUET formats only)
    dc.load_from_s3(
        table="my_table",
        schema="my_schema",
        relative_path="imports/data.csv",
        format="CSV",
        column_list=["id", "name", "created_at"]
    )
    """
    self._ensure_connected()
    self._mark_activity()

    # Get the appropriate S3 connector for this operation
    working_bucket = self._resolve_s3_bucket(s3_bucket)
    working_s3 = self._get_s3_for_bucket(working_bucket)

    if relative_path:
        relative_path = relative_path.strip().lstrip("/")
        # Collapse repeated slashes
        relative_path = re.sub(r"/+", "/", relative_path)
    else:
        log_and_raise_error(self._logger, "relative_path cannot be empty")

    s3_path = f"s3://{working_bucket}/{relative_path}"

    format = format.upper()
    if format not in ["PARQUET", "CSV", "JSON", "AVRO", "ORC"]:
        log_and_raise_error(self._logger, f"Invalid format '{format}'. Must be PARQUET, CSV, JSON, AVRO, or ORC.")

    # Check if files exist in S3 before proceeding. A failed listing is only
    # a warning (best effort); an empty listing is a hard error.
    files_in_path = None
    try:
        files_in_path = working_s3.list_files(prefix=relative_path, bucket=working_bucket)
        if not files_in_path:
            log_and_raise_error(
                self._logger,
                f"No files found at S3 path: {s3_path}. Please verify the path exists and contains data files.",
            )
        self._logger.debug(f"Found {len(files_in_path)} file(s) in {s3_path}")
    except ValueError:
        # Re-raise errors from log_and_raise_error
        raise
    except Exception as e:
        self._logger.warning(f"Could not verify files in S3 path {s3_path}: {e}")

    try:
        fully_qualified_table = f"{schema}.{table}"

        table_exists = self._table_exists(table_name=table, schema_name=schema)
        recreate = table_exists and drop_existing_table

        # Handle table creation/recreation logic
        if recreate or not table_exists:
            # Validate BEFORE dropping anything: a non-PARQUET table cannot be
            # rebuilt automatically, so dropping it first would just destroy it.
            if format != "PARQUET":
                if recreate:
                    log_and_raise_error(
                        self._logger,
                        f"Cannot auto-create table from {format} format. "
                        "Auto-creation only supported for PARQUET format. "
                        "Create the table manually before using drop_existing_table with non-Parquet formats.",
                    )
                else:
                    log_and_raise_error(
                        self._logger,
                        f"Table {fully_qualified_table} does not exist. "
                        f"Auto-creation only supported for PARQUET format. "
                        f"Please create the table manually for {format} format.",
                    )

            if recreate:
                self._logger.debug(f"Dropping table {fully_qualified_table}")
                with self._lock:
                    self._cancel_idle_timer()
                    self._last_activity = time.time()
                    self._ensure_connected()
                    self.cursor.execute(f"DROP TABLE IF EXISTS {fully_qualified_table}")
                self._start_idle_timer()

            schema_source = "column_types" if column_types else "Parquet schema"
            if recreate:
                self._logger.debug(f"Recreating table {fully_qualified_table} from {schema_source}")
            else:
                self._logger.debug(
                    f"Table {fully_qualified_table} does not exist. Creating from {schema_source}..."
                )

            if column_types:
                # Use all columns from column_types (not just those in Parquet):
                # some columns may be all-null and not present in the Parquet file.
                self._create_table_from_column_types(
                    schema=schema,
                    table=table,
                    column_types=column_types,
                    column_order=list(column_types.keys()),
                    identity_column=identity_column,
                )
            else:
                self._create_table_from_parquet(
                    table=table,
                    schema=schema,
                    s3_bucket=working_bucket,
                    relative_path=relative_path,
                    known_files=files_in_path,
                    column_types=column_types,
                    identity_column=identity_column,
                )

        # Truncate if requested (only if not already dropped)
        if truncate_before_load and not drop_existing_table:
            self._logger.debug(f"Truncating table {fully_qualified_table}")
            with self._lock:
                self._cancel_idle_timer()
                self._last_activity = time.time()
                self._ensure_connected()
                self.cursor.execute(f"TRUNCATE TABLE {fully_qualified_table}")
            self._start_idle_timer()

        credentials = boto3.Session().get_credentials()
        if credentials is None:
            log_and_raise_error(
                self._logger, "Unable to retrieve AWS credentials. Ensure boto3 is configured correctly."
            )

        aws_access_key = credentials.access_key
        aws_secret_key = credentials.secret_key
        aws_session_token = credentials.token

        # Check for identity columns that should be excluded from COPY
        identity_columns = []
        if table_exists and format == "PARQUET":
            try:
                # Use pg_catalog tables to detect identity columns via adsrc
                identity_query = f"""
                    SELECT a.attname
                    FROM pg_class c, pg_attribute a, pg_attrdef d, pg_namespace n
                    WHERE c.oid = a.attrelid
                    AND c.relkind = 'r'
                    AND a.attrelid = d.adrelid
                    AND a.attnum = d.adnum
                    AND d.adsrc LIKE '%identity%'
                    AND c.relnamespace = n.oid
                    AND n.nspname = '{schema}'
                    AND c.relname = '{table}'
                    ORDER BY a.attnum
                """
                identity_results = self.execute_sql(identity_query, fetch_all=True)
                identity_columns = [row[0] for row in identity_results] if identity_results else []
                if identity_columns:
                    self._logger.debug(f"Detected identity columns: {identity_columns}")
            except Exception as e:
                self._logger.warning(f"Could not check for identity columns: {e}")

        # Build the COPY target, optionally followed by a column list.
        # NOTE: column_list is NOT used for PARQUET (file columns must match the
        # table in order). EXCEPTION: when identity columns were detected, an
        # explicit list excluding them is emitted.
        copy_target = fully_qualified_table
        if format == "PARQUET" and identity_columns:
            try:
                columns_query = f"""
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_schema = '{schema}'
                    AND table_name = '{table}'
                    ORDER BY ordinal_position
                """
                all_columns_results = self.execute_sql(columns_query, fetch_all=True)
                all_columns = [row[0] for row in all_columns_results] if all_columns_results else []
                copy_columns = [col for col in all_columns if col not in identity_columns]
                # An empty "()" would be invalid SQL, so only emit a non-empty list.
                if copy_columns:
                    copy_target += f" ({', '.join(copy_columns)})"
                    self._logger.debug(f"COPY will use column list (excluding identity): {copy_columns}")
            except Exception as e:
                self._logger.warning(f"Could not build column list for COPY: {e}")
        elif column_list and format != "PARQUET":
            copy_target += f" ({', '.join(column_list)})"

        clauses = [
            f"COPY {copy_target}",
            f"FROM '{s3_path}'",
            f"ACCESS_KEY_ID '{aws_access_key}'",
            f"SECRET_ACCESS_KEY '{aws_secret_key}'",
        ]

        # Temporary credentials additionally require the session token
        if aws_session_token:
            clauses.append(f"SESSION_TOKEN '{aws_session_token}'")

        # Format-specific options
        if format == "PARQUET":
            clauses.append("FORMAT AS PARQUET")
        elif format == "CSV":
            clauses.append("FORMAT AS CSV")
            if ignore_header:
                clauses.append(f"IGNOREHEADER {ignore_header}")
            if delimiter:
                clauses.append(f"DELIMITER '{delimiter}'")
        elif format == "JSON":
            clauses.append("FORMAT AS JSON 'auto'")
        elif format == "AVRO":
            clauses.append("FORMAT AS AVRO 'auto'")
        elif format == "ORC":
            clauses.append("FORMAT AS ORC")

        if format in ("CSV", "JSON"):
            if date_format:
                clauses.append(f"DATEFORMAT '{date_format}'")
            if time_format:
                clauses.append(f"TIMEFORMAT '{time_format}'")

        if format == "CSV":
            if blank_as_null:
                clauses.append("BLANKSASNULL")
            if empty_as_null:
                clauses.append("EMPTYASNULL")
            if null_as:
                clauses.append(f"NULL AS '{null_as}'")

        if format in ("CSV", "JSON") and accept_invalid_chars:
            clauses.append("ACCEPTINVCHARS")

        # Guard against max_error=None as well as 0
        if max_error and max_error > 0:
            clauses.append(f"MAXERROR {max_error}")

        clauses.append("STATUPDATE ON" if stat_update else "STATUPDATE OFF")

        if format != "PARQUET":
            clauses.append("COMPUPDATE ON" if compupdate else "COMPUPDATE OFF")

        copy_query = "\n".join(clauses) + "\n"

        # Never write AWS credentials to the log: mask them in the logged copy only.
        loggable_query = copy_query
        for secret in (aws_secret_key, aws_session_token, aws_access_key):
            if secret:
                loggable_query = loggable_query.replace(secret, "***")

        self._logger.debug(f"Starting COPY from {s3_path} to {fully_qualified_table}")
        self._logger.debug(f"COPY query: {loggable_query}")

        with self._lock:
            self._cancel_idle_timer()
            self._last_activity = time.time()
            self._ensure_connected()
            self.cursor.execute(copy_query)
        self._mark_activity()

        self._logger.debug(f"Successfully copied data from {s3_path} to {fully_qualified_table}")

        return fully_qualified_table

    except Exception as e:
        log_and_raise_error(self._logger, f"Error during COPY from S3: {e}")
    finally:
        self._start_idle_timer()
1516
+
1517
def _table_exists(self, table_name: str, schema_name: str) -> bool:
    """
    Report whether a regular (non-external) table is listed in ``pg_tables``.

    Parameters
    ----------
    table_name : str
        Table name to look up (matched against ``tablename``).
    schema_name : str
        Schema name to look up (matched against ``schemaname``).

    Returns
    -------
    bool
        First column of the single row returned by the ``SELECT EXISTS`` query.
    """
    self._ensure_connected()
    self._mark_activity()
    try:
        # The idle timer is restarted in ``finally`` whether or not the lookup succeeds.
        lookup_row = self.execute_sql(
            f"""
            SELECT EXISTS (
            SELECT 1
            FROM pg_tables
            WHERE schemaname = '{schema_name}'
            AND tablename = '{table_name}'
            )
            """,
            fetch_result=True,
        )
        return lookup_row[0]
    finally:
        self._start_idle_timer()
1535
+
1536
def _create_table_from_parquet(
    self,
    table: str,
    schema: str,
    s3_bucket: str,
    relative_path: str,
    known_files: list[str] = None,
    column_types: dict[str, str] = None,
    identity_column: str | bool = None,
):
    """
    Helper method to create a table by inferring schema from a Parquet file.

    The first key ending in ``.parquet`` (case-insensitive) is used for schema
    inference. ``known_files`` is searched first; only when it yields nothing is
    S3 listed under ``relative_path``.

    Parameters
    ----------
    table : str
        Name of the table to create (without schema).
    schema : str
        Schema in which the table is created.
    s3_bucket : str
        Bucket used to obtain the S3 connector.
    relative_path : str
        Path inside the bucket that holds the Parquet file(s). Surrounding
        whitespace and leading slashes are stripped; empty/None is treated as
        the bucket root.
    known_files : list[str], optional
        Keys already listed by the caller (avoids a redundant S3 listing).
    column_types : dict[str, str], optional
        Override types for specific columns. Example: {'date_col': 'DATE', 'status': 'VARCHAR(50)'}
    identity_column : str | bool, optional
        Add an auto-incrementing IDENTITY column as the first column
        (``True`` names it 'id'). Skipped with a warning when a column of that
        name already exists in the Parquet schema.

    Raises
    ------
    Any failure (including a missing Parquet file) is reported through
    ``log_and_raise_error`` with the prefix "Failed to create table from Parquet schema".
    """
    # Get the appropriate S3 connector for this operation
    working_s3 = self._get_s3_for_bucket(s3_bucket)

    try:
        # Normalise up front so the listing prefix is always defined, even
        # when the caller passes an empty or None path.
        relative_path = (relative_path or "").strip().lstrip("/")
        if relative_path and not relative_path.endswith("/"):
            relative_path_for_listing = relative_path + "/"
        else:
            relative_path_for_listing = relative_path

        # Use known files if provided (avoids redundant S3 listing)
        parquet_file_key = next(
            (key for key in known_files or [] if key.lower().endswith(".parquet")),
            None,
        )

        # Fall back to S3 listing only if no parquet found in known files.
        # dict.fromkeys de-duplicates while keeping order, so each prefix is listed once.
        if not parquet_file_key:
            for prefix_candidate in dict.fromkeys([relative_path_for_listing, relative_path]):
                try:
                    files_prefix = working_s3.list_files_in_prefix(prefix=prefix_candidate) or []
                except Exception as list_error:
                    # Best effort: a failed listing just means "try the next prefix".
                    self._logger.debug(f"Could not list S3 prefix '{prefix_candidate}': {list_error}")
                    files_prefix = []
                parquet_file_key = next(
                    (key for key in files_prefix if key.lower().endswith(".parquet")),
                    None,
                )
                if parquet_file_key:
                    break

        if not parquet_file_key:
            log_and_raise_error(
                self._logger,
                f"No .parquet file found in '{relative_path}' for schema inference.",
                FileNotFoundError,
            )

        self._logger.debug(f"Found Parquet file for schema inference: {parquet_file_key}")

        # Listings may return keys prefixed with the bucket name; strip it
        # before building the path handed to pyarrow.
        parquet_key_for_path = parquet_file_key
        if parquet_key_for_path.startswith(f"{working_s3.bucket}/"):
            parquet_key_for_path = parquet_key_for_path[len(working_s3.bucket) + 1 :]
        parquet_s3_path = working_s3.get_path(parquet_key_for_path)
        df_schema = pq.read_schema(parquet_s3_path)
        columns_with_types = self._convert_pyarrow_schema_to_sql(
            df_schema, partition_defs=None, column_types=column_types
        )
        column_names = [col for col, _ in columns_with_types]
        self._check_redshift_reserved_words(column_names)

        fully_qualified_table = f"{schema}.{table}"
        create_table_query = f"CREATE TABLE {fully_qualified_table} (\n"

        # Add identity column first if requested
        if identity_column:
            id_col_name = "id" if identity_column is True else identity_column

            # Check if identity column name already exists in Parquet columns
            if id_col_name.lower() in [col.lower() for col in column_names]:
                self._logger.warning(
                    f"Column '{id_col_name}' already exists in Parquet schema. "
                    f"Skipping identity column to avoid duplicate. "
                    f"Use a different identity_column name or remove '{id_col_name}' from your DataFrame."
                )
            else:
                create_table_query += f" {id_col_name} BIGINT IDENTITY(1,1),\n"
                self._logger.debug(f"Adding identity column '{id_col_name}' to table {fully_qualified_table}")

        create_table_query += ",\n".join([f" {col} {dtype}" for col, dtype in columns_with_types])
        create_table_query += "\n);"

        self.execute_sql(create_table_query)

        self._logger.debug(f"Successfully created table {fully_qualified_table} from Parquet schema")

    except Exception as e:
        log_and_raise_error(self._logger, f"Failed to create table from Parquet schema: {e}")
1644
+
1645
+ def _find_parquet_file(self, s3_connector, relative_path: str, known_files: list[str] = None) -> str:
1646
+ """
1647
+ Find a Parquet file in the given path for schema reading.
1648
+
1649
+ Returns the S3 key of a parquet file.
1650
+ """
1651
+ if relative_path:
1652
+ relative_path = relative_path.strip().lstrip("/")
1653
+
1654
+ # If it's a direct file path ending in .parquet, use it
1655
+ if relative_path.lower().endswith(".parquet"):
1656
+ return relative_path
1657
+
1658
+ # Otherwise it's a directory, find a parquet file in it
1659
+ if not relative_path.endswith("/"):
1660
+ relative_path_for_listing = relative_path + "/"
1661
+ else:
1662
+ relative_path_for_listing = relative_path
1663
+
1664
+ # Use known files if provided
1665
+ if known_files:
1666
+ for file_key in known_files:
1667
+ if file_key.lower().endswith(".parquet"):
1668
+ return file_key
1669
+
1670
+ # Fall back to S3 listing
1671
+ candidates = [relative_path_for_listing, relative_path_for_listing.lstrip("/"), relative_path]
1672
+ for prefix_candidate in candidates:
1673
+ try:
1674
+ files_prefix = s3_connector.list_files_in_prefix(prefix=prefix_candidate) or []
1675
+ except Exception:
1676
+ files_prefix = []
1677
+ for file_key in files_prefix:
1678
+ if file_key.lower().endswith(".parquet"):
1679
+ return file_key
1680
+
1681
+ log_and_raise_error(
1682
+ self._logger,
1683
+ f"No .parquet file found in '{relative_path}' for schema reading.",
1684
+ FileNotFoundError,
1685
+ )
1686
+
1687
def _create_table_from_column_types(
    self,
    table: str,
    schema: str,
    column_types: dict[str, str],
    column_order: list[str] = None,
    identity_column: str | bool = None,
):
    """
    Build and execute a CREATE TABLE statement from an explicit type mapping.

    Used when the caller supplies the column types itself, so no Parquet
    schema inference is needed.

    Parameters
    ----------
    table : str
        Table name
    schema : str
        Schema name
    column_types : dict[str, str]
        Dictionary mapping column names to SQL types
    column_order : list[str], optional
        Order in which the columns are emitted. When None, the keys of
        column_types are used in sorted order.
    identity_column : str | bool, optional
        Prepend a ``BIGINT IDENTITY(1,1)`` column (``True`` names it 'id').
        Skipped with a warning if a column of that name is already listed.

    Raises
    ------
    Any failure is reported through ``log_and_raise_error`` with the prefix
    "Failed to create table from column_types".
    """
    try:
        ordered_columns = sorted(column_types.keys()) if column_order is None else column_order

        # Check for reserved words
        self._check_redshift_reserved_words(ordered_columns)

        fully_qualified_table = f"{schema}.{table}"
        ddl_parts = [f"CREATE TABLE {fully_qualified_table} (\n"]

        # The identity column, when requested, goes ahead of all other columns
        if identity_column:
            id_col_name = identity_column
            if identity_column is True:
                id_col_name = "id"

            # Case-insensitive clash check against the declared columns
            taken_names = {name.lower() for name in ordered_columns}
            if id_col_name.lower() not in taken_names:
                ddl_parts.append(f" {id_col_name} BIGINT IDENTITY(1,1),\n")
                self._logger.debug(f"Adding identity column '{id_col_name}' to table {fully_qualified_table}")
            else:
                self._logger.warning(
                    f"Column '{id_col_name}' already exists in column_types. "
                    f"Skipping identity column to avoid duplicate. "
                    f"Use a different identity_column name or remove '{id_col_name}' from column_types."
                )

        column_definitions = []
        for name in ordered_columns:
            column_definitions.append(f" {name} {column_types[name]}")
        ddl_parts.append(",\n".join(column_definitions))
        ddl_parts.append("\n);")

        self.execute_sql("".join(ddl_parts))

        self._logger.debug(f"Successfully created table {fully_qualified_table} from column_types")

    except Exception as e:
        log_and_raise_error(self._logger, f"Failed to create table from column_types: {e}")
1747
+
1748
+ @staticmethod
1749
+ def _check_redshift_reserved_words(column_names):
1750
+ reserved_words = {
1751
+ "aes128",
1752
+ "aes256",
1753
+ "all",
1754
+ "allowoverwrite",
1755
+ "analyse",
1756
+ "analyze",
1757
+ "and",
1758
+ "any",
1759
+ "array",
1760
+ "as",
1761
+ "asc",
1762
+ "authorization",
1763
+ "backup",
1764
+ "between",
1765
+ "binary",
1766
+ "blanksasnull",
1767
+ "both",
1768
+ "by",
1769
+ "bzip2",
1770
+ "case",
1771
+ "cast",
1772
+ "check",
1773
+ "collate",
1774
+ "column",
1775
+ "constraint",
1776
+ "create",
1777
+ "credentials",
1778
+ "cross",
1779
+ "current_date",
1780
+ "current_time",
1781
+ "current_timestamp",
1782
+ "current_user",
1783
+ "current_user_id",
1784
+ "default",
1785
+ "deferrable",
1786
+ "deflate",
1787
+ "defrag",
1788
+ "delta",
1789
+ "delta32k",
1790
+ "desc",
1791
+ "disable",
1792
+ "distinct",
1793
+ "do",
1794
+ "else",
1795
+ "emptyasnull",
1796
+ "enable",
1797
+ "encode",
1798
+ "encrypt",
1799
+ "encryption",
1800
+ "end",
1801
+ "except",
1802
+ "explicit",
1803
+ "false",
1804
+ "for",
1805
+ "foreign",
1806
+ "freeze",
1807
+ "from",
1808
+ "full",
1809
+ "globaldict256",
1810
+ "globaldict64k",
1811
+ "grant",
1812
+ "group",
1813
+ "gzip",
1814
+ "having",
1815
+ "identity",
1816
+ "ignore",
1817
+ "ilike",
1818
+ "in",
1819
+ "initially",
1820
+ "inner",
1821
+ "intersect",
1822
+ "into",
1823
+ "is",
1824
+ "isnull",
1825
+ "join",
1826
+ "leading",
1827
+ "left",
1828
+ "like",
1829
+ "limit",
1830
+ "localtime",
1831
+ "localtimestamp",
1832
+ "lun",
1833
+ "luns",
1834
+ "lzo",
1835
+ "minus",
1836
+ "mostly13",
1837
+ "mostly32",
1838
+ "mostly8",
1839
+ "natural",
1840
+ "new",
1841
+ "not",
1842
+ "notnull",
1843
+ "null",
1844
+ "nulls",
1845
+ "off",
1846
+ "offline",
1847
+ "offset",
1848
+ "oid",
1849
+ "old",
1850
+ "on",
1851
+ "only",
1852
+ "open",
1853
+ "or",
1854
+ "order",
1855
+ "outer",
1856
+ "overlaps",
1857
+ "parallel",
1858
+ "partition",
1859
+ "percent",
1860
+ "permissions",
1861
+ "placing",
1862
+ "primary",
1863
+ "raw",
1864
+ "readratio",
1865
+ "recover",
1866
+ "references",
1867
+ "respect",
1868
+ "rejectlog",
1869
+ "resort",
1870
+ "restore",
1871
+ "right",
1872
+ "select",
1873
+ "session_user",
1874
+ "similar",
1875
+ "some",
1876
+ "sysdate",
1877
+ "system",
1878
+ "table",
1879
+ "tag",
1880
+ "tdes",
1881
+ "text255",
1882
+ "text32k",
1883
+ "then",
1884
+ "timestamp",
1885
+ "to",
1886
+ "top",
1887
+ "trailing",
1888
+ "true",
1889
+ "truncate",
1890
+ "unload",
1891
+ "user",
1892
+ "using",
1893
+ "verbose",
1894
+ "wallet",
1895
+ "when",
1896
+ "where",
1897
+ "with",
1898
+ }
1899
+ reserved_found = [col for col in column_names if col.lower() in reserved_words]
1900
+ if reserved_found:
1901
+ raise ValueError(
1902
+ f"The following column names are reserved words in Redshift and cannot be used: {reserved_found}"
1903
+ )
1904
+
1905
def _cast_null_columns_for_parquet(
    self, df: pd.DataFrame | pl.DataFrame, column_types: dict[str, str]
) -> pd.DataFrame | pl.DataFrame:
    """
    Cast null-only columns to appropriate types based on column_types specification.

    This ensures the Parquet file schema matches the table schema when using column_types.
    PyArrow's 'null' type is incompatible with Redshift Spectrum, so we cast null columns
    to typed null columns (e.g., int64 with nulls).

    Only columns that are entirely null AND listed in ``column_types`` are touched.
    The SQL type is reduced to its base name case-insensitively
    (``"decimal(10,2)"`` -> ``"DECIMAL"``); unknown base types fall back to a
    string/object column.

    Parameters
    ----------
    df : pd.DataFrame | pl.DataFrame
        The DataFrame to process. It is not modified; a new frame is returned
        whenever a cast is applied.
    column_types : dict[str, str]
        Mapping of column name to Redshift SQL type.

    Returns
    -------
    pd.DataFrame | pl.DataFrame
        The DataFrame with null-only columns cast to typed columns.
    """

    def base_type_of(sql_type: str) -> str:
        # "DECIMAL(10,2)" -> "DECIMAL", "varchar(50)" -> "VARCHAR"
        return sql_type.split("(")[0].strip().upper()

    if isinstance(df, pl.DataFrame):
        polars_type_mapping = {
            "DATE": pl.Date,
            "TIMESTAMP": pl.Datetime,
            "BIGINT": pl.Int64,
            "INT": pl.Int32,
            "INTEGER": pl.Int32,
            "SMALLINT": pl.Int16,
            "DECIMAL": pl.Float64,
            "NUMERIC": pl.Float64,
            "DOUBLE PRECISION": pl.Float64,
            "REAL": pl.Float32,
            "BOOLEAN": pl.Boolean,
            "VARCHAR": pl.Utf8,
            "TEXT": pl.Utf8,
        }

        null_cols = [col for col in df.columns if df[col].dtype == pl.Null and col in column_types]
        if null_cols:
            self._logger.info(f"Casting {len(null_cols)} null column(s) to proper types: {null_cols}")
            df = df.with_columns(
                [
                    pl.col(col).cast(polars_type_mapping.get(base_type_of(column_types[col]), pl.Utf8)).alias(col)
                    for col in null_cols
                ]
            )
        return df

    # Pandas. Nullable extension dtypes are used for integers and booleans so the
    # values stay null after the cast (a plain "bool" cast would turn None into
    # False and NaN into True).
    pandas_type_mapping = {
        "BIGINT": "Int64",
        "INT": "Int32",
        "INTEGER": "Int32",
        "SMALLINT": "Int16",
        "DECIMAL": "float64",  # Pandas doesn't have native decimal, use float
        "NUMERIC": "float64",
        "DOUBLE PRECISION": "float64",
        "REAL": "float32",
        "BOOLEAN": "boolean",
        "VARCHAR": "object",
        "TEXT": "object",
    }
    datetime_types = {"DATE", "TIMESTAMP"}

    # Membership is tested first so columns without a declared type are never scanned.
    null_cols = [col for col in df.columns if col in column_types and df[col].isna().all()]
    if null_cols:
        self._logger.info(f"Casting {len(null_cols)} null column(s) to proper types: {null_cols}")
        # Shallow copy: column re-assignment below must not leak into the caller's frame.
        df = df.copy(deep=False)
        for col in null_cols:
            base_type = base_type_of(column_types[col])
            if base_type in datetime_types:
                df[col] = pd.to_datetime(df[col])
            else:
                df[col] = df[col].astype(pandas_type_mapping.get(base_type, "object"))

    return df
1973
+
1974
def _convert_null_columns_to_string(self, df: pd.DataFrame | pl.DataFrame) -> pd.DataFrame | pl.DataFrame:
    """
    Turn all-null columns into string columns.

    A column holding nothing but nulls is inferred by PyArrow as the 'null' type,
    which Redshift Spectrum cannot read from Parquet files. Such columns are
    converted to a string type here so the written file stays loadable.

    Parameters
    ----------
    df : pd.DataFrame | pl.DataFrame
        The DataFrame to process. A pandas frame is updated in place (and also
        returned); a polars frame is replaced by a new one.

    Returns
    -------
    pd.DataFrame | pl.DataFrame
        The DataFrame with null-type columns converted to string.
    """
    if isinstance(df, pl.DataFrame):
        # Polars marks all-null columns with the dedicated Null dtype.
        null_columns = [name for name in df.columns if df[name].dtype == pl.Null]
        if not null_columns:
            return df
        self._logger.info(f"Converting {len(null_columns)} all-null column(s) to String type: {null_columns}")
        return df.with_columns([pl.col(name).cast(pl.String).alias(name) for name in null_columns])

    # Pandas: only columns that are entirely null AND that PyArrow would type as
    # 'null' need converting (e.g. an all-NaN float column keeps its float type).
    null_columns = []
    for name in df.columns:
        series = df[name]
        if not series.isna().all():
            continue
        try:
            would_be_null_type = pa.types.is_null(pa.Array.from_pandas(series).type)
        except Exception:
            # Arrow could not determine a type: fall back to "object column of all None".
            would_be_null_type = series.dtype == object or pd.api.types.is_object_dtype(series)
        if would_be_null_type:
            null_columns.append(name)

    if not null_columns:
        return df

    self._logger.info(f"Converting {len(null_columns)} all-null column(s) to string type: {null_columns}")
    for name in null_columns:
        df[name] = df[name].astype("string")
    return df
2020
+
2021
def _convert_pyarrow_schema_to_sql(
    self,
    arrow_schema: pa.Schema,
    partition_defs: list[tuple[str, str]] = None,
    column_types: dict[str, str] = None,
) -> list[tuple[str, str]]:
    """
    Convert PyArrow schema to SQL column definitions.

    Partition columns are left out of the result. For every other field the SQL
    type is chosen in this order:

    1. a user override from ``column_types`` (always wins, including for
       all-null columns and for Arrow types this method cannot map itself);
    2. ``VARCHAR(65535)`` for PyArrow 'null' typed (all-null) columns;
    3. the built-in Arrow -> SQL mapping.

    Fields with an unsupported Arrow type and no override are skipped with a warning.

    Parameters
    ----------
    arrow_schema : pa.Schema
        PyArrow schema to convert
    partition_defs : list[tuple[str, str]], optional
        List of partition column definitions (name, type)
    column_types : dict[str, str], optional
        Override types for specific columns. Useful for null-only columns
        or when you want specific types. Example: {'date_col': 'DATE', 'status': 'VARCHAR(50)'}

    Returns
    -------
    list[tuple[str, str]]
        ``(column_name, sql_type)`` pairs in schema order.

    Notes
    -----
    ``log_and_raise_error`` is invoked when the schema is empty, or when no column
    could be derived and no partition columns were given.
    """

    def infer_sql_type(arrow_type):
        # Built-in mapping; returns None for Arrow types without a SQL counterpart here.
        types = pa.types
        if types.is_string(arrow_type) or types.is_large_string(arrow_type):
            return "VARCHAR(65535)"
        if types.is_int64(arrow_type):
            return "BIGINT"
        if types.is_int32(arrow_type):
            return "INT"
        if types.is_int16(arrow_type) or types.is_int8(arrow_type):
            return "SMALLINT"
        if types.is_float64(arrow_type):
            return "DOUBLE PRECISION"
        if types.is_float32(arrow_type) or types.is_float16(arrow_type):
            return "REAL"
        if types.is_boolean(arrow_type):
            return "BOOLEAN"
        if types.is_date32(arrow_type) or types.is_date64(arrow_type):
            return "DATE"
        if types.is_timestamp(arrow_type):
            return "TIMESTAMP"
        if types.is_decimal(arrow_type):
            return f"DECIMAL({arrow_type.precision}, {arrow_type.scale})"
        if types.is_binary(arrow_type) or types.is_large_binary(arrow_type):
            return "VARBYTE(65535)"
        return None

    columns_with_types = []
    partition_column_names = {p_name for p_name, _p_type in partition_defs} if partition_defs else set()
    column_type_overrides = column_types or {}

    for field in arrow_schema:
        col_name = field.name
        arrow_type = field.type

        if col_name in partition_column_names:
            self._logger.debug(f"Column '{col_name}' is a partition column, skipping from main schema definition.")
            continue

        is_null_type = pa.types.is_null(arrow_type)

        if col_name in column_type_overrides:
            # The override is consulted BEFORE the support check so that a user can
            # rescue a column whose Arrow type has no built-in mapping.
            sql_type = column_type_overrides[col_name]
            if is_null_type:
                self._logger.info(
                    f"Column '{col_name}' has PyArrow 'null' type (all values are null). "
                    f"Using user-specified type: {sql_type}"
                )
            else:
                self._logger.info(f"Column '{col_name}' type overridden by user: {sql_type}")
        elif is_null_type:
            # Columns with only null values get 'null' type in PyArrow.
            sql_type = "VARCHAR(65535)"
            self._logger.info(
                f"Column '{col_name}' has PyArrow 'null' type (all values are null). "
                "Defaulting to VARCHAR(65535)."
            )
        else:
            sql_type = infer_sql_type(arrow_type)
            if sql_type is None:
                self._logger.warning(
                    f"Unsupported PyArrow type '{arrow_type}' for column '{col_name}'. Skipping column."
                )
                continue

        columns_with_types.append((col_name, sql_type))

    if not columns_with_types:
        schema_names = arrow_schema.names
        if not schema_names:
            log_and_raise_error(self._logger, "Parquet schema is empty. Cannot create table.")
        elif not partition_column_names:
            log_and_raise_error(
                self._logger,
                "No columns could be derived from the Parquet schema. All columns might be of unsupported types "
                "or are partition columns.",
            )
        elif all(name in partition_column_names for name in schema_names):
            self._logger.info(
                "All columns in Parquet schema are partition columns. "
                "Table will be created with only partition columns if defined."
            )
        else:
            # Previously silent: every non-partition column was skipped as unsupported.
            self._logger.warning(
                "No non-partition columns could be derived from the Parquet schema; "
                "all of them have unsupported types."
            )
    return columns_with_types
2121
+
2122
@staticmethod
def infer_column_types_from_dataframe(
    df: pd.DataFrame | pl.DataFrame, overrides: dict[str, str] | None = None
) -> dict[str, str]:
    """
    Generate a column_types dictionary from a DataFrame for use with load/create methods.

    This helper method infers Redshift SQL types from DataFrame column types and provides
    sensible defaults for nullable columns. You can override specific columns as needed.

    Unsigned integer columns are widened to the next larger signed SQL type where one
    exists (e.g. polars ``UInt32`` -> ``BIGINT``), because the same-width signed type
    cannot hold the upper half of the unsigned range. 64-bit unsigned columns map to
    ``BIGINT``; values above the signed 64-bit maximum do not fit — pass an override
    for such columns.

    Parameters
    ----------
    df : pd.DataFrame | pl.DataFrame
        The DataFrame to infer column types from. Anything that is not a pandas
        DataFrame is treated as a polars DataFrame.
    overrides : dict[str, str], optional
        Override specific column types. These take precedence over inferred types.
        Example: {'created_date': 'DATE', 'status': 'VARCHAR(50)'}

    Returns
    -------
    dict[str, str]
        Dictionary mapping column names to Redshift SQL types, in DataFrame column order.
        Columns of any type without a specific mapping (strings, objects, categoricals,
        all-null columns, ...) default to ``VARCHAR(65535)``.

    Examples
    --------
    # Basic usage - infer all types
    df = pd.DataFrame({
        'id': [1, 2, 3],
        'name': ['Alice', 'Bob', None],
        'amount': [10.5, 20.0, 30.5],
        'created_at': pd.to_datetime(['2024-01-01', '2024-01-02', None])
    })
    column_types = DataConnector.infer_column_types_from_dataframe(df)
    # Returns: {'id': 'BIGINT', 'name': 'VARCHAR(65535)', 'amount': 'DOUBLE PRECISION',
    #           'created_at': 'TIMESTAMP'}

    # With overrides for null-only or specific columns
    df = pd.DataFrame({
        'id': [1, 2],
        'future_date': [None, None]  # All nulls
    })
    column_types = DataConnector.infer_column_types_from_dataframe(
        df,
        overrides={'future_date': 'DATE'}  # Specify what the type should be
    )
    # Returns: {'id': 'BIGINT', 'future_date': 'DATE'}

    # Use with create_table_from_dataframe
    dc = DataConnector()
    column_types = DataConnector.infer_column_types_from_dataframe(
        df,
        overrides={'birth_date': 'DATE', 'salary': 'DECIMAL(12,2)'}
    )
    dc.create_table_from_dataframe(df, 'employees', 'hr', column_types=column_types)
    """
    import numpy as np

    default_type = "VARCHAR(65535)"
    column_types = {}
    overrides_dict = overrides or {}

    if isinstance(df, pd.DataFrame):
        for col in df.columns:
            # Check override first
            if col in overrides_dict:
                column_types[col] = overrides_dict[col]
                continue

            dtype = df[col].dtype

            if pd.api.types.is_integer_dtype(dtype):
                if dtype == np.int32 or dtype == "Int32":
                    column_types[col] = "INT"
                elif dtype == np.int16 or dtype == "Int16" or dtype == np.int8 or dtype == "Int8":
                    column_types[col] = "SMALLINT"
                else:
                    # int64 / Int64 and every unsigned or otherwise unlisted integer dtype
                    column_types[col] = "BIGINT"
            elif pd.api.types.is_float_dtype(dtype):
                column_types[col] = "REAL" if dtype == np.float32 else "DOUBLE PRECISION"
            elif pd.api.types.is_bool_dtype(dtype):
                column_types[col] = "BOOLEAN"
            elif pd.api.types.is_datetime64_any_dtype(dtype):
                column_types[col] = "TIMESTAMP"
            else:
                # Strings, objects, categoricals and unknown types (including all-null columns)
                column_types[col] = default_type

    else:  # Polars DataFrame
        for col in df.columns:
            # Check override first
            if col in overrides_dict:
                column_types[col] = overrides_dict[col]
                continue

            dtype = df[col].dtype

            if dtype in (pl.Int64, pl.UInt64, pl.UInt32):
                # UInt32 is widened: its maximum exceeds the signed 32-bit range of INT.
                column_types[col] = "BIGINT"
            elif dtype in (pl.Int32, pl.UInt16):
                # UInt16 is widened: its maximum exceeds the signed 16-bit range of SMALLINT.
                column_types[col] = "INT"
            elif dtype in (pl.Int16, pl.Int8, pl.UInt8):
                column_types[col] = "SMALLINT"
            elif dtype == pl.Float64:
                column_types[col] = "DOUBLE PRECISION"
            elif dtype == pl.Float32:
                column_types[col] = "REAL"
            elif dtype == pl.Boolean:
                column_types[col] = "BOOLEAN"
            elif dtype == pl.Date:
                column_types[col] = "DATE"
            elif dtype == pl.Datetime:
                # Comparing against the bare class matches every time unit.
                column_types[col] = "TIMESTAMP"
            else:
                # Utf8/String, Categorical, Null (all-null) and unknown types
                column_types[col] = default_type

    return column_types
2259
+
2260
def create_table_from_dataframe(
    self,
    df: pd.DataFrame | pl.DataFrame,
    table: str,
    schema: str,
    drop_existing_table: bool = False,
    column_types: dict[str, str] = None,
    s3_connector: S3Connector = None,
    s3_path: str = None,
    file_name: str = None,
    delete_s3_files_after: bool = True,
    truncate_before_load: bool = False,
    identity_column: str | bool = None,
    stat_update: bool = False,
):
    """
    Create a Redshift table from a DataFrame by uploading to S3, then loading via COPY.

    By default, uses a temporary S3 path that is automatically cleaned up. You can also
    specify a custom S3 path (useful for appending data or keeping files for audit).

    If the target table already exists and is not being dropped, its current column
    list (identity columns excluded) is read from the catalog and the DataFrame is
    aligned to it: columns missing from the DataFrame are added as all-null columns
    and DataFrame columns that are not part of the table are dropped. The
    DataFrame passed in is not modified when missing columns are added.

    Parameters
    ----------
    df : pd.DataFrame | pl.DataFrame
        The DataFrame to load into Redshift.
    table : str
        The name of the target table (without schema).
    schema : str
        The schema name where the table will be created.
    drop_existing_table : bool, optional
        If True, drops the existing table before creating. Defaults to False.
    column_types : dict[str, str], optional
        Override specific column types when creating the table.
        Useful for columns with all NULL values or to enforce specific types.
        DECIMAL types are replaced by DOUBLE PRECISION before saving/loading.
        Example: {'created_date': 'DATE', 'status': 'VARCHAR(50)', 'amount': 'DECIMAL(10,2)'}
    s3_connector : S3Connector, optional
        An existing S3Connector instance to use. If provided along with s3_path,
        uses that connector's bucket and the specified path. Useful when you want
        to control where files are stored (e.g., for appending or audit trails).
    s3_path : str, optional
        Custom relative path in S3 where the parquet file will be saved.
        If not provided, uses a temporary path that gets cleaned up.
        When s3_connector is provided, this path is relative to the connector's s3_root.
        Example: 'data/my_table/' or 'staging/loads/2024/'
    file_name : str, optional
        Custom name for the parquet file (a trailing '.parquet' is stripped).
        If not provided, defaults to '{table}_data'.
        Example: 'my_custom_file' will create 'my_custom_file.parquet'
    delete_s3_files_after : bool, optional
        If True, deletes the S3 files after loading into Redshift. Defaults to True.
        Set to False if you want to keep the files (e.g., for backup or appending).
        Temporary paths are always deleted after a successful load.
    truncate_before_load : bool, optional
        If True, truncates the table before loading (keeps table structure).
        Defaults to False. Use this for appending with clean slate without dropping table.
    identity_column : str | bool, optional
        Add an auto-incrementing IDENTITY column to the table.
        - If True: Creates a column named 'id' as BIGINT IDENTITY(1,1)
        - If str: Creates a column with the specified name as BIGINT IDENTITY(1,1)
        - If None/False: No identity column (default)
        The identity column is added as the first column in the table.
        Example: identity_column=True or identity_column='row_id'
    stat_update : bool, optional
        If True, updates table statistics after the COPY operation (runs ANALYZE).
        Requires table/database owner privileges. Defaults to False.
        Set to True only if you have the necessary permissions and want updated statistics.
        Note: If you encounter permission errors like "only table or database owner can analyze",
        keep this as False.

    Returns
    -------
    None

    Notes
    -----
    Any failure is reported through ``log_and_raise_error``; in that case the
    uploaded S3 files are kept for debugging. The idle timer is restarted in all cases.

    Examples
    --------
    # Simple usage - creates table from DataFrame (temp files auto-deleted)
    df = pd.DataFrame({'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie']})
    dc.create_table_from_dataframe(df, table='my_table', schema='my_schema')

    # Drop and recreate table
    dc.create_table_from_dataframe(
        df,
        table='my_table',
        schema='my_schema',
        drop_existing_table=True
    )

    # Use custom S3 path and keep files (for audit/backup)
    dc.create_table_from_dataframe(
        df,
        table='my_table',
        schema='my_schema',
        s3_path='staging/my_table_loads/',
        delete_s3_files_after=False
    )

    # Use existing S3 connector with custom path (s3_root automatically prepended)
    s3 = S3Connector(bucket='my-bucket', s3_root='project/')
    dc.create_table_from_dataframe(
        df,
        table='my_table',
        schema='my_schema',
        s3_connector=s3,
        s3_path='loads/daily/',  # Will be saved to 'project/loads/daily/'
        delete_s3_files_after=False
    )

    # Use custom file name to keep organized data
    dc.create_table_from_dataframe(
        df,
        table='my_table',
        schema='my_schema',
        s3_path='partitioned/country=mx/',
        file_name='scores_20241216',  # Creates 'scores_20241216.parquet'
        delete_s3_files_after=False
    )

    # Append to existing table (truncate first, keep files)
    dc.create_table_from_dataframe(
        df,
        table='my_table',
        schema='my_schema',
        truncate_before_load=True,
        s3_path='data/incremental/',
        delete_s3_files_after=False
    )

    # Use polars DataFrame
    import polars as pl
    df_polars = pl.DataFrame({'id': [1, 2, 3], 'value': [10, 20, 30]})
    dc.create_table_from_dataframe(df_polars, table='polars_table', schema='my_schema')
    """
    import uuid
    from datetime import datetime

    self._ensure_connected()
    self._mark_activity()

    fully_qualified_table = f"{schema}.{table}"
    working_s3 = None
    working_path = None
    is_temp_path = False
    working_bucket = None

    try:
        # Determine S3 connector and bucket
        if s3_connector is not None:
            working_s3 = s3_connector
            working_bucket = s3_connector.bucket
        else:
            working_bucket = self._resolve_s3_bucket()
            working_s3 = self._get_s3_for_bucket(working_bucket)

        # A caller-supplied connector may carry an s3_root that has to be prepended.
        # Trailing slashes are stripped so 'project/' and 'project' both give 'project/<path>'
        # instead of 'project//<path>'.
        connector_has_root = bool(s3_connector is not None and s3_connector.s3_root)
        root_prefix = ""
        if connector_has_root:
            stripped_root = s3_connector.s3_root.rstrip("/")
            if stripped_root:
                root_prefix = f"{stripped_root}/"

        if s3_path is not None:
            # Use custom path provided by user
            relative_dir = s3_path.strip().lstrip("/")
            if not relative_dir.endswith("/"):
                relative_dir += "/"
        else:
            # Generate a unique temporary path
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            unique_id = str(uuid.uuid4())[:8]
            relative_dir = f"_temp/dataframe_load/{timestamp}_{unique_id}/"
        working_path = f"{root_prefix}{relative_dir}"
        is_temp_path = s3_path is None

        # Use custom file name if provided, otherwise default to {table}_data.
        # The .parquet extension is appended later, so strip it if the user included it.
        if file_name is not None:
            parquet_file_name = file_name.removesuffix(".parquet")
        else:
            parquet_file_name = f"{table}_data"

        # Handle null columns only if column_types is NOT provided
        # When column_types is provided, save_dataframe handles all type conversions
        if column_types is None:
            df = self._convert_null_columns_to_string(df)

        adjusted_column_types = column_types

        # Check if table already exists and if we're appending (not dropping)
        table_exists = self._table_exists(table_name=table, schema_name=schema)
        if table_exists and not drop_existing_table:
            # Table exists and we're appending - read actual table schema
            self._logger.info(f"Table {schema}.{table} exists. Reading existing schema to match Parquet file...")
            try:
                # Escape single quotes so the names cannot break out of the SQL string literals.
                schema_literal = schema.replace("'", "''")
                table_literal = table.replace("'", "''")
                # Use pg_catalog tables to get schema and detect identity columns
                schema_query = f"""
                    SELECT a.attname as column_name,
                           format_type(a.atttypid, a.atttypmod) as data_type,
                           CASE WHEN d.adsrc LIKE '%identity%' THEN true ELSE false END as is_identity
                    FROM pg_class c
                    JOIN pg_attribute a ON c.oid = a.attrelid
                    JOIN pg_namespace n ON c.relnamespace = n.oid
                    LEFT JOIN pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum
                    WHERE c.relkind = 'r'
                      AND n.nspname = '{schema_literal}'
                      AND c.relname = '{table_literal}'
                      AND a.attnum > 0
                      AND NOT a.attisdropped
                    ORDER BY a.attnum
                """
                existing_schema = self.execute_sql(schema_query, fetch_all=True)

                # Map Redshift types to our column_types format
                # EXCLUDE identity columns - they are auto-generated and should not be in the DataFrame/Parquet
                existing_column_types = {}
                for row in existing_schema or []:
                    col_name = row[0]
                    data_type = row[1]
                    is_identity = row[2] if len(row) > 2 else False

                    # Skip identity columns - they shouldn't be in the Parquet file
                    if is_identity is True:
                        self._logger.debug(f"Excluding identity column '{col_name}' from Parquet file")
                        continue

                    # Convert Redshift types to standard format
                    if "character varying" in data_type or "varchar" in data_type:
                        existing_column_types[col_name] = "VARCHAR(65535)"
                    elif data_type == "bigint":
                        existing_column_types[col_name] = "BIGINT"
                    elif data_type == "integer":
                        existing_column_types[col_name] = "INT"
                    elif data_type == "double precision":
                        existing_column_types[col_name] = "DOUBLE PRECISION"
                    elif data_type == "date":
                        existing_column_types[col_name] = "DATE"
                    elif "timestamp" in data_type:
                        existing_column_types[col_name] = "TIMESTAMP"
                    else:
                        existing_column_types[col_name] = data_type.upper()

                if not existing_column_types:
                    # An empty column list would strip every column from the DataFrame;
                    # treat it as a failed schema read and use the fallback below.
                    raise ValueError("catalog query returned no loadable columns")

                # Use existing table schema instead of user-provided column_types (if any)
                self._logger.info(
                    "Using existing table schema for Parquet compatibility (excluding identity columns)"
                )
                adjusted_column_types = existing_column_types
                column_order_for_table = list(existing_column_types.keys())
            except Exception as e:
                self._logger.warning(f"Could not read existing table schema: {e}. Using column_types as-is.")
                if column_types:
                    column_order_for_table = list(column_types.keys())
                else:
                    # No column_types and couldn't read schema - use DataFrame columns
                    column_order_for_table = list(df.columns)
        elif column_types:
            # Table doesn't exist or we're dropping it - use column_types as-is
            column_order_for_table = list(column_types.keys())
        else:
            # No column_types and table doesn't exist - use DataFrame columns
            column_order_for_table = list(df.columns)

        # Important: We must include ALL columns in the Parquet file
        # Even all-null columns must be included so the Parquet file matches the table schema
        missing_columns = [col for col in column_order_for_table if col not in df.columns]
        if isinstance(df, pl.DataFrame):
            # Polars frames do not support item assignment; add null columns via with_columns.
            if missing_columns:
                df = df.with_columns([pl.lit(None).alias(col) for col in missing_columns])
            df = df.select(column_order_for_table)
        else:
            if missing_columns:
                # Shallow copy so the new columns are not added to the caller's DataFrame.
                df = df.copy(deep=False)
                for col in missing_columns:
                    df[col] = None
            # Reorder DataFrame columns to match the order
            df = df[column_order_for_table]
        parquet_columns = column_order_for_table

        # User-supplied column_types that were not replaced by the existing table schema:
        # DECIMAL types must be DOUBLE PRECISION (Redshift can't load Parquet double into numeric)
        if adjusted_column_types is not None and column_types and adjusted_column_types == column_types:
            adjusted_column_types = column_types.copy()
            for col, col_type in column_types.items():
                base_type = col_type.split("(")[0].strip().upper()
                if base_type == "DECIMAL":
                    adjusted_column_types[col] = "DOUBLE PRECISION"
                    self._logger.debug(f"Adjusted {col} from DECIMAL to DOUBLE PRECISION for Parquet compatibility")

        # Save DataFrame to S3
        # Pass column_types to ensure Parquet schema matches table schema
        self._logger.debug(f"Saving DataFrame to Parquet with column_types: {adjusted_column_types}")
        self._logger.debug(f"Parquet columns that will be saved: {parquet_columns}")

        # If we used s3_connector and already prepended its s3_root to working_path,
        # pass an empty s3_root to save_dataframe to prevent a double s3_root.
        save_s3_root = "" if connector_has_root else None

        working_s3.save_dataframe(
            df,
            directory=working_path,
            file_name=parquet_file_name,
            file_format="parquet",
            column_types=adjusted_column_types,
            s3_root=save_s3_root,
        )

        # Load from S3 to Redshift
        # Use specific file path to avoid Redshift trying to read hidden files in directory
        parquet_file_path = f"{working_path}{parquet_file_name}.parquet"
        self._logger.debug(f"Loading Parquet from: s3://{working_bucket}/{parquet_file_path}")
        self._logger.debug(f"Adjusted column types for table: {adjusted_column_types}")
        # NOTE: column_list is NOT used for PARQUET format - Parquet columns must match table order
        self.load_from_s3(
            table=table,
            schema=schema,
            relative_path=parquet_file_path,
            s3_bucket=working_bucket,
            format="PARQUET",
            drop_existing_table=drop_existing_table,
            truncate_before_load=truncate_before_load,
            column_types=adjusted_column_types,  # Use adjusted types (DECIMAL -> DOUBLE PRECISION)
            column_list=None,  # Do NOT use column_list for PARQUET - columns must match table order
            identity_column=identity_column,
            stat_update=stat_update,
        )

        self._logger.info(f"Successfully loaded data into {fully_qualified_table}")

        # Clean up S3 files if requested (always delete temp paths on success).
        # Cleanup is best-effort: failures are logged and never fail the load.
        should_delete = delete_s3_files_after or is_temp_path
        if should_delete and working_s3 and working_path:
            try:
                files_to_delete = working_s3.list_files(prefix=working_path, bucket=working_bucket)
                for file_key in files_to_delete:
                    try:
                        working_s3.delete_file(file_key, bucket=working_bucket)
                    except Exception as delete_error:
                        self._logger.warning(f"Failed to delete file {file_key}: {delete_error}")
            except Exception as cleanup_error:
                self._logger.warning(f"Failed to clean up S3 files: {cleanup_error}")

    except Exception as e:
        # On error, keep files for debugging (only meaningful once a path was determined)
        if working_path:
            self._logger.info(f"Keeping S3 files for debugging at: s3://{working_bucket}/{working_path}")
        log_and_raise_error(self._logger, f"Error creating table from DataFrame: {e}")
    finally:
        self._start_idle_timer()