rdxz2-utill 0.0.2__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
utill/my_bq.py CHANGED
@@ -1,358 +1,726 @@
1
- import humanize
1
+ import datetime
2
2
  import math
3
3
  import os
4
4
  import shutil
5
-
5
+ import textwrap
6
+ import time
6
7
  from enum import Enum
7
- from google.cloud import bigquery, storage
8
+ from enum import StrEnum
9
+ from enum import auto
10
+
11
+ from google.cloud import bigquery
12
+ from google.cloud.exceptions import NotFound
13
+ from humanize import naturalsize
14
+ from humanize import precisedelta
8
15
  from loguru import logger
9
- from textwrap import dedent
10
-
11
- from .my_const import ByteSize
12
- from .my_csv import read_header, combine as csv_combine, compress
13
- from .my_datetime import current_datetime_str
14
- from .my_env import envs
15
- from .my_file import make_sure_path_is_directory
16
- from .my_gcs import GCS
17
- from .my_queue import ThreadingQ
18
- from .my_string import replace_nonnumeric
19
- from .my_xlsx import csv_to_xlsx
20
-
21
- MAP__PYTHON_DTYPE__BQ_DTYPE = {
22
- int: 'INTEGER',
23
- str: 'STRING',
24
- float: 'STRING',
16
+
17
+ from . import my_csv
18
+ from . import my_datetime
19
+ from . import my_env
20
+ from . import my_gcs
21
+ from . import my_string
22
+ from . import my_xlsx
23
+
24
+
25
+ PY_DATA_TYPE__BQ_DATA_TYPE = {
26
+ int: "INTEGER",
27
+ str: "STRING",
28
+ float: "STRING",
25
29
  }
26
30
 
27
31
 
28
- class LoadStrategy(Enum):
29
- OVERWRITE = 1
30
- APPEND = 2
32
+ class DataFileFormat(StrEnum):
33
+ CSV = "CSV"
34
+ JSON = "JSON"
35
+ AVRO = "AVRO"
36
+ PARQUET = "PARQUET"
37
+ ORC = "ORC"
31
38
 
32
39
 
33
- class Dtype:
34
- INT64 = 'INT64'
35
- INTEGER = 'INTEGER'
36
- FLOAT64 = 'FLOAT64'
40
+ class DataFileCompression(StrEnum):
41
+ GZIP = "GZIP"
42
+ SNAPPY = "SNAPPY"
43
+
37
44
 
38
- DECIMAL = 'DECIMAL'
45
+ class LoadStrategy(Enum):
46
+ OVERWRITE = auto()
47
+ APPEND = auto()
39
48
 
40
- STRING = 'STRING'
41
- JSON = 'JSON'
42
49
 
43
- DATE = 'DATE'
44
- TIME = 'TIME'
45
- DATETIME = 'DATETIME'
46
- TIMESTAMP = 'TIMESTAMP'
50
+ class Dtype:
51
+ INT64 = "INT64"
52
+ INTEGER = "INTEGER"
53
+ FLOAT64 = "FLOAT64"
47
54
 
48
- BOOL = 'BOOL'
55
+ DECIMAL = "DECIMAL"
49
56
 
50
- ARRAY_INT64 = 'ARRAY<INT64>'
51
- ARRAY_INTEGER = 'ARRAY<INTEGER>'
52
- ARRAY_FLOAT64 = 'ARRAY<FLOAT64>'
53
- ARRAY_STRING = 'ARRAY<STRING>'
54
- ARRAY_JSON = 'ARRAY<JSON>'
55
- ARRAY_DATE = 'ARRAY<DATE>'
56
- ARRAY_DATETIME = 'ARRAY<DATETIME>'
57
- ARRAY_TIMESTAMP = 'ARRAY<TIMESTAMP>'
58
- ARRAY_BOOL = 'ARRAY<BOOL>'
57
+ STRING = "STRING"
58
+ JSON = "JSON"
59
59
 
60
+ DATE = "DATE"
61
+ TIME = "TIME"
62
+ DATETIME = "DATETIME"
63
+ TIMESTAMP = "TIMESTAMP"
60
64
 
61
- class BQ():
62
- def __init__(self, project: str = None):
63
- self.project = project or envs.GCP_PROJECT_ID
65
+ BOOL = "BOOL"
64
66
 
65
- self.client = bigquery.Client(project=self.project)
66
- logger.debug(f'BQ client open, project: {self.project or "<application-default>"}')
67
+ ARRAY_INT64 = "ARRAY<INT64>"
68
+ ARRAY_INTEGER = "ARRAY<INTEGER>"
69
+ ARRAY_FLOAT64 = "ARRAY<FLOAT64>"
70
+ ARRAY_STRING = "ARRAY<STRING>"
71
+ ARRAY_JSON = "ARRAY<JSON>"
72
+ ARRAY_DATE = "ARRAY<DATE>"
73
+ ARRAY_DATETIME = "ARRAY<DATETIME>"
74
+ ARRAY_TIMESTAMP = "ARRAY<TIMESTAMP>"
75
+ ARRAY_BOOL = "ARRAY<BOOL>"
67
76
 
68
- def __enter__(self):
69
- return self
70
77
 
71
- def __exit__(self, exc_type, exc_value, exc_tb):
72
- self.close_client()
78
+ class BQ:
79
+ def __init__(self, location: str | None = None, project_id: str = None):
80
+ if project_id is None and my_env.envs.GCP_PROJECT_ID is None:
81
+ logger.warning("Using ADC for BigQuery authentication")
73
82
 
74
- def execute_query(self, query: str | list[str], dry_run: bool = False, parameters: dict = {}) -> bigquery.QueryJob:
75
- multi = type(query) == list
76
- if multi:
77
- query = '\n'.join([x if str(x).strip().endswith(';') else x + ';' for x in query if x])
78
- else:
79
- query = query.strip()
83
+ # if location is None and my_env.envs.GCP_REGION is None:
84
+ # raise ValueError('GCP region must be set in environment variables.')
80
85
 
81
- # Build paramters
86
+ self.client = bigquery.Client(
87
+ project=project_id or my_env.envs.GCP_PROJECT_ID,
88
+ location=location or my_env.envs.GCP_REGION,
89
+ )
90
+ logger.debug(f"BQ client open, project: {self.client.project}")
91
+
92
+ # MARK: Query execution
93
+
94
+ def execute_query(
95
+ self,
96
+ query: str | list[str],
97
+ parameters: dict = {},
98
+ dry_run: bool = False,
99
+ temporary_table: bool = False,
100
+ ) -> bigquery.QueryJob:
101
+ # Reconstruct query, handle multiple queries in a single job
102
+ is_multi = isinstance(query, list)
103
+ queries = query if is_multi else [query]
104
+ queries = [textwrap.dedent(q).strip() for q in queries]
105
+ queries = [
106
+ q if q.endswith(";") else q + ";" for q in queries
107
+ ] # Append ';' character for each query
108
+ query = "\n".join(queries)
109
+
110
+ # Evaluate parameter
82
111
  query_parameters = []
83
112
  for parameter, value in parameters.items():
84
- if type(value) == list:
85
- query_parameters.append(bigquery.ArrayQueryParameter(parameter, MAP__PYTHON_DTYPE__BQ_DTYPE[type(value[0])], value))
113
+ is_array = isinstance(value, list)
114
+ value_type_py = type(value[0]) if is_array else type(value)
115
+ if value_type_py not in PY_DATA_TYPE__BQ_DATA_TYPE:
116
+ raise ValueError(
117
+ f"Unsupported type for parameter {parameter}: {value_type_py}. Supported types are: {list(PY_DATA_TYPE__BQ_DATA_TYPE.keys())}"
118
+ )
119
+
120
+ value_type_bq = PY_DATA_TYPE__BQ_DATA_TYPE[value_type_py]
121
+
122
+ # Handle data type conversions
123
+ if value_type_py == datetime.date:
124
+ value = (
125
+ [v.strftime("%Y-%m-%d") for v in value]
126
+ if is_array
127
+ else value.strftime("%Y-%m-%d")
128
+ )
129
+
130
+ if is_array:
131
+ query_parameters.append(
132
+ bigquery.ArrayQueryParameter(parameter, value_type_bq, value)
133
+ )
86
134
  else:
87
- query_parameters.append(bigquery.ScalarQueryParameter(parameter, MAP__PYTHON_DTYPE__BQ_DTYPE[type(value)], value))
135
+ query_parameters.append(
136
+ bigquery.ScalarQueryParameter(parameter, value_type_bq, value)
137
+ )
88
138
 
89
- logger.debug(f'🔎 Query:\n{query}')
90
- query_job_config = bigquery.QueryJobConfig(dry_run=dry_run, query_parameters=query_parameters)
139
+ logger.debug(f"🔎 Query:\n{query}")
140
+ query_job_config = bigquery.QueryJobConfig(
141
+ dry_run=dry_run, query_parameters=query_parameters
142
+ )
143
+ if temporary_table:
144
+ query_job_config.destination = None
145
+ t = time.time()
91
146
  query_job = self.client.query(query, job_config=query_job_config)
92
- query_job.result() # Wait query execution
147
+ (
148
+ logger.info(
149
+ f"Job tracking: https://console.cloud.google.com/bigquery?project={self.client.project}&j=bq:{self.client.location}:{query_job.job_id}&page=queryresults"
150
+ )
151
+ if not dry_run
152
+ else None
153
+ )
154
+ query_job.result() # Wait for the job to complete
155
+ elapsed = precisedelta(datetime.timedelta(seconds=time.time() - t))
93
156
 
94
- if not multi:
95
- logger.debug(f'[Job ID] {query_job.job_id}, [Processed] {humanize.naturalsize(query_job.total_bytes_processed)}, [Billed] {humanize.naturalsize(query_job.total_bytes_billed)}, [Affected] {query_job.num_dml_affected_rows or 0} row(s)',)
157
+ if not is_multi:
158
+ logger.info(
159
+ f"[Job ID] {query_job.job_id}, [Processed] {naturalsize(query_job.total_bytes_processed)}, [Billed] {naturalsize(query_job.total_bytes_billed)}, [Affected] {query_job.num_dml_affected_rows or 0} row(s), [Elapsed] {elapsed}",
160
+ )
96
161
  else:
97
- logger.debug(f'[Job ID] {query_job.job_id}')
162
+ logger.info(f"[Job ID] {query_job.job_id} [Elapsed] {elapsed}")
98
163
 
99
- jobs: list[bigquery.QueryJob] = self.client.list_jobs(parent_job=query_job.job_id)
100
- [logger.debug(f'[Script ID] {job.job_id}, [Processed] {humanize.naturalsize(job.total_bytes_processed)}, [Billed] {humanize.naturalsize(job.total_bytes_billed)}, [Affected] {job.num_dml_affected_rows or 0} row(s)',) for job in jobs]
164
+ jobs: list[bigquery.QueryJob] = list(
165
+ self.client.list_jobs(parent_job=query_job.job_id)
166
+ )
167
+ [
168
+ logger.info(
169
+ f"[Script ID] {job.job_id}, [Processed] {naturalsize(job.total_bytes_processed)}, [Billed] {naturalsize(job.total_bytes_billed)}, [Affected] {job.num_dml_affected_rows or 0} row(s)",
170
+ )
171
+ for job in jobs
172
+ ]
101
173
 
102
174
  return query_job
103
175
 
104
- def create_table(self, bq_table_fqn: str, schema: list[bigquery.SchemaField], partition_col: str, cluster_cols: list[str]):
105
- table = bigquery.Table(bq_table_fqn, schema=schema)
106
-
107
- if partition_col:
108
- table.time_partitioning = bigquery.TimePartitioning(field=partition_col)
109
- table.partitioning_type = 'DAY'
110
-
111
- if cluster_cols:
112
- table.clustering_fields = cluster_cols
113
-
114
- bq_table = self.client.create_table(table)
115
- logger.info(f'✅ Table created: {bq_table_fqn}')
116
- return bq_table
117
-
118
- def drop_table(self, bq_table_fqn: str):
119
- self.client.delete_table(bq_table_fqn)
120
- logger.info(f'✅ Table dropped: {bq_table_fqn}')
121
-
122
- def load_data_into(self, bq_table_fqn: str, gcs_path: list[str] | str, cols: dict[str, Dtype], partition_col: str = None, cluster_cols: list[str] = None, overwrite: bool = False):
123
- if type(gcs_path) == str:
124
- gcs_path = [gcs_path]
125
- gcs_path_str = ',\n'.join([f' \'{x}\'' for x in gcs_path])
126
-
127
- load_data_keyword = 'OVERWRITE' if overwrite else 'INTO'
128
- cols_str = ',\n'.join([f' `{x}` {y}' for x, y in cols.items()])
129
- cluster_cols_str = ','.join([f'`{x}`' for x in cluster_cols]) if cluster_cols else None
130
- query = dedent(
131
- f'''
132
- LOAD DATA {load_data_keyword} `{bq_table_fqn}` (
133
- {cols_str}
176
+ # MARK: Table operations
177
+
178
+ def create_table(
179
+ self,
180
+ dst_table_fqn: str,
181
+ query: str,
182
+ query_parameters: dict = {},
183
+ *,
184
+ description: str | None = None,
185
+ schema: list[dict] | None = None,
186
+ partition_by: str | None = None,
187
+ clustering_fields: list[str] | None = None,
188
+ expiration_timestamp_utc: datetime.datetime | None = None,
189
+ require_partition_filter: bool = False,
190
+ replace: bool = False,
191
+ ):
192
+ self.raise_for_invalid_table_fqn(dst_table_fqn)
193
+
194
+ # Construct table options
195
+ logger.debug("Constructing table options ...")
196
+ table_options = []
197
+ if expiration_timestamp_utc:
198
+ table_options.append(
199
+ f" expiration_timestamp='{expiration_timestamp_utc.isoformat()}'"
134
200
  )
135
- {f"PARTITION BY `{partition_col}`" if partition_col is not None else "-- No partition column provided"}
136
- {f"CLUSTER BY {cluster_cols_str}" if cluster_cols_str is not None else "-- No cluster column provided"}
137
- FROM FILES(
138
- skip_leading_rows=1,
139
- allow_quoted_newlines=true,
140
- format='csv',
141
- compression='gzip',
142
- uris = [
143
- {gcs_path_str}
144
- ]
145
- );
146
- '''
201
+ if partition_by and require_partition_filter:
202
+ table_options.append(" require_partition_filter=TRUE")
203
+ if description:
204
+ table_options.append(f" description='{description}'")
205
+
206
+ # Check if table exists
207
+ logger.debug("Checking if destination table exists ...")
208
+ dst_table_project_id, dst_table_dataset_id, dst_table_id = (
209
+ self.get_table_fqn_parts(dst_table_fqn)
147
210
  )
148
-
149
- logger.debug(f'⌛ Load data into: {bq_table_fqn}')
150
- query_job = self.execute_query(query)
151
- logger.info(f'✅ Load data into: {bq_table_fqn}')
152
- return query_job
153
-
154
- def export_data(self, query: str, gcs_path: str, pre_query: str = None):
155
- if '*' not in gcs_path:
156
- raise ValueError('GCS path need to have a single \'*\' wildcard character')
157
-
158
- query = dedent(
159
- f'''
160
- EXPORT DATA OPTIONS (
161
- uri='{gcs_path}',
162
- format='csv',
163
- compression='gzip',
164
- overwrite=true,
165
- header=true,
166
- field_delimiter=',')
167
- AS (
168
- {query}
169
- );
170
- '''
211
+ table_exist = self.is_table_exists(
212
+ project_id=dst_table_project_id,
213
+ dataset_id=dst_table_dataset_id,
214
+ table_id=dst_table_id,
171
215
  )
172
216
 
173
- if pre_query:
174
- query = [pre_query, query]
175
-
176
- logger.debug(f'⌛ Export data into: {gcs_path}')
177
- query_job = self.execute_query(query)
178
- logger.info(f'✅ Exported data into: {gcs_path}')
179
- return query_job
180
-
181
- def upload_csv(self, src_filename: str, bq_table_fqn: str, cols: dict[str, Dtype], partition_col: str = None, cluster_cols: list[str] = None, load_strategy: LoadStrategy = LoadStrategy.APPEND):
182
- # <<----- START: Validation
183
-
184
- if load_strategy not in LoadStrategy:
185
- raise ValueError('Invalid load strategy')
217
+ # Construct beautiful query string
218
+ if table_exist and not replace:
219
+ logger.debug("Table exists, constructing INSERT query ...")
220
+ query_parts = [f"INSERT INTO `{dst_table_fqn}`"]
221
+ if schema:
222
+ schema_str = ",\n".join([column["name"] for column in schema])
223
+ query_parts.append(f"(\n{schema_str}\n)")
224
+ if table_options:
225
+ table_options_str = ",\n".join(table_options)
226
+ query_parts.append(f"OPTIONS (\n{table_options_str}\n)")
227
+ else:
228
+ logger.debug("Table not exist, constructing CREATE TABLE query ...")
229
+ query_parts = [
230
+ f"CREATE OR REPLACE TABLE `{dst_table_fqn}`",
231
+ ]
232
+ if schema:
233
+ schema_str = ",\n".join(
234
+ [f' {column["name"]} {column["data_type"]}' for column in schema]
235
+ )
236
+ query_parts.append(f"(\n{schema_str}\n)")
237
+ if partition_by:
238
+ query_parts.append(f"PARTITION BY {partition_by}")
239
+ if clustering_fields:
240
+ clustering_fields_str = ", ".join(
241
+ [f"`{field}`" for field in clustering_fields]
242
+ )
243
+ query_parts.append(f"CLUSTER BY {clustering_fields_str}")
244
+ if table_options:
245
+ table_options_str = ",\n".join(table_options)
246
+ query_parts.append(f"OPTIONS (\n{table_options_str}\n)")
247
+ query_parts.append("AS")
248
+ query_parts.append(textwrap.dedent(query).strip())
249
+
250
+ # Execute
251
+ logger.debug("Executing query ...")
252
+ query = "\n".join(query_parts)
253
+ self.execute_query(query, parameters=query_parameters)
186
254
 
187
- if not src_filename.endswith('.csv'):
188
- raise ValueError('Please provide file path with .csv extension!')
255
+ def drop_table(self, bq_table_fqn: str):
256
+ logger.info(f"Dropping table: {bq_table_fqn} ...")
257
+ self.raise_for_invalid_table_fqn(bq_table_fqn)
258
+ self.client.delete_table(bq_table_fqn, not_found_ok=True)
259
+
260
+ # MARK: Table data
261
+
262
+ def load_data(
263
+ self,
264
+ src_gcs_uri: str,
265
+ dst_table_fqn: str,
266
+ *,
267
+ schema: list[dict] | None = None,
268
+ partition_by: str | None = None,
269
+ clustering_fields: list[str] | None = None,
270
+ field_delimiter: str = ",",
271
+ load_strategy: LoadStrategy = LoadStrategy.APPEND,
272
+ format: DataFileFormat = DataFileFormat.CSV,
273
+ compression=None,
274
+ ):
275
+
276
+ self.raise_for_invalid_table_fqn(dst_table_fqn)
277
+
278
+ logger.debug(f"Loading CSV from {src_gcs_uri} into {dst_table_fqn} ...")
279
+
280
+ # Construct LOAD options
281
+ logger.debug("Constructing LOAD options ...")
282
+ load_options = [ # https://cloud.google.com/bigquery/docs/reference/standard-sql/load-statements#load_option_list
283
+ f" format='{format}'",
284
+ f" uris=['{src_gcs_uri}']",
285
+ ]
286
+ if format == DataFileFormat.CSV:
287
+ load_options.append(" skip_leading_rows=1")
288
+ load_options.append(f" field_delimiter='{field_delimiter}'")
289
+ load_options.append(" allow_quoted_newlines=true")
290
+ if compression:
291
+ load_options.append(f" compression='{compression}'")
292
+ load_options_str = ",\n".join(load_options)
293
+
294
+ # Construct beautiful query string
295
+ logger.debug("Constructing LOAD query ...")
296
+ schema_str = ",\n".join(
297
+ [f' {column["name"]} {column["data_type"]}' for column in schema]
298
+ )
299
+ query_parts = [
300
+ f'LOAD DATA {"OVERWRITE" if load_strategy == LoadStrategy.OVERWRITE else "INTO"} `{dst_table_fqn}` (\n{schema_str}\n)'
301
+ ]
302
+ if partition_by:
303
+ query_parts.append(f"PARTITION BY {partition_by}")
304
+ if clustering_fields:
305
+ clustering_fields_str = ", ".join(
306
+ [f"`{field}`" for field in clustering_fields]
307
+ )
308
+ query_parts.append(f"CLUSTER BY {clustering_fields_str}")
309
+ query_parts.append(f"FROM FILES (\n{load_options_str}\n)")
310
+ query = "\n".join(query_parts)
311
+
312
+ # Execute
313
+ logger.debug("Executing query ...")
314
+ self.execute_query(query)
315
+
316
+ def export_data(
317
+ self,
318
+ query: str,
319
+ dst_gcs_uri: str,
320
+ *,
321
+ parameters: dict = {},
322
+ format: DataFileFormat = DataFileFormat.CSV,
323
+ compression: DataFileCompression | None = None,
324
+ header: bool = True,
325
+ delimiter: str = ",",
326
+ ):
327
+ logger.debug(f"Exporting query into {dst_gcs_uri} ...")
328
+
329
+ # GCS uri validation
330
+ if (
331
+ format == DataFileFormat.CSV
332
+ and compression == DataFileCompression.GZIP
333
+ and not dst_gcs_uri.endswith(".gz")
334
+ ):
335
+ raise ValueError(
336
+ "GCS path need to ends with .gz if using compression = GCSCompression.GZIP"
337
+ )
338
+ elif (
339
+ format == DataFileFormat.CSV
340
+ and compression != DataFileCompression.GZIP
341
+ and not dst_gcs_uri.endswith(".csv")
342
+ ):
343
+ raise ValueError(
344
+ "GCS path need to ends with .csv if using format = GCSExportFormat.CSV"
345
+ )
346
+ elif format == DataFileFormat.PARQUET and not dst_gcs_uri.endswith(".parquet"):
347
+ raise ValueError(
348
+ "GCS path need to ends with .parquet if using format = GCSExportFormat.PARQUET"
349
+ )
189
350
 
190
- if partition_col is not None:
191
- if partition_col not in cols.keys():
192
- raise ValueError(f'Partition \'{partition_col}\' not exists in columns!')
193
- if cluster_cols is not None:
194
- if cluster_cols not in cols.keys():
195
- raise ValueError(f'Cluster \'{cluster_cols}\' not exists in columns!')
351
+ # Construct options
352
+ logger.debug("Constructing EXPORT options ...")
353
+ options = [
354
+ f" uri='{dst_gcs_uri}'",
355
+ f" format='{format}'",
356
+ " overwrite=TRUE",
357
+ ]
358
+ if format == DataFileFormat.CSV:
359
+ options.append(
360
+ f" field_delimiter='{delimiter}'",
361
+ )
362
+ if header:
363
+ options.append(
364
+ f' header={"true" if header else "false"}',
365
+ )
366
+ if compression:
367
+ options.append(f" compression='{compression}'")
368
+ options_str = ",\n".join(options)
369
+
370
+ # Construct beautiful query string
371
+ logger.debug("Constructing EXPORT query ...")
372
+ query = (
373
+ f"EXPORT DATA OPTIONS (\n"
374
+ f"{options_str}\n"
375
+ f")\n"
376
+ f"AS (\n"
377
+ f"{textwrap.dedent(query).strip()}\n"
378
+ f");"
379
+ )
196
380
 
197
- # Build list of columns with its datatypes
198
- csv_cols = set(read_header(src_filename))
199
- excessive_cols = set(cols.keys()) - set(csv_cols)
200
- if excessive_cols:
201
- raise ValueError(f'{len(excessive_cols)} columns not exists in CSV file: {", ".join(excessive_cols)}')
202
- nonexistent_cols = set(csv_cols) - set(cols.keys())
203
- if nonexistent_cols:
204
- raise ValueError(f'{len(nonexistent_cols)} columns from CSV are missing: {", ".join(nonexistent_cols)}')
381
+ # Execute
382
+ logger.debug("Executing query ...")
383
+ self.execute_query(query=query, parameters=parameters)
384
+
385
+ def upload_csv(
386
+ self,
387
+ src_filepath: str,
388
+ dst_table_fqn: str,
389
+ schema: list[dict] | None = None,
390
+ gcs_bucket: str | None = None,
391
+ partition_by: str = None,
392
+ clustering_fields: list[str] = None,
393
+ compression: DataFileCompression | None = None,
394
+ load_strategy: LoadStrategy = LoadStrategy.APPEND,
395
+ ):
396
+ self.raise_for_invalid_table_fqn(dst_table_fqn)
397
+
398
+ if compression == DataFileCompression.GZIP and not src_filepath.endswith(".gz"):
399
+ raise ValueError(
400
+ "Please provide file path with .gz extension if using compression = GZIP"
401
+ )
402
+ elif not src_filepath.endswith(".csv"):
403
+ raise ValueError("Please provide file path with .csv extension")
205
404
 
206
- # END: Validation ----->>
405
+ src_filename, src_fileextension = os.path.splitext(src_filepath)
406
+ src_filename = os.path.basename(src_filename) # Only get filename
207
407
 
208
- # <<----- START: Upload to GCS
408
+ # # <<----- START: Upload to GCS
209
409
 
210
- gcs = GCS(self.project)
211
- tmp_dir = f'tmp/upload__{current_datetime_str()}'
410
+ # gcs = GCS(self.project_id)
411
+ # tmp_dir = f'tmp/upload__{current_datetime_str()}'
212
412
 
213
- # This will compress while splitting the compressed file to a certain bytes size because of GCS 4GB file limitation
214
- # A single file can produce more than one compressed file in GCS
215
- def producer(src_file: str):
216
- for dst_file in compress(src_file, keep=True, max_size_bytes=ByteSize.GB * 3):
217
- yield (dst_file, )
413
+ # # This will compress while splitting the compressed file to a certain bytes size because of GCS 4GB file limitation
414
+ # # A single file can produce more than one compressed file in GCS
415
+ # def producer(src_file: str):
416
+ # for dst_file in compress(src_file,
417
+ # keep=True, max_size_bytes=ByteSize.GB * 3):
418
+ # yield (dst_file, )
218
419
 
219
- def consumer(dst_file: str):
220
- remote_file_name = f'{tmp_dir}/{replace_nonnumeric(os.path.basename(dst_file), "_").lower()}.csv.gz'
221
- logger.debug(f'Uploading {dst_file} to {remote_file_name}...')
222
- blob = gcs.upload(dst_file, remote_file_name, mv=True)
223
- return blob
420
+ # def consumer(dst_file: str):
421
+ # remote_file_name = f'{tmp_dir}/{replace_nonnumeric(os.path.basename(dst_file), "_").lower()}.csv.gz'
422
+ # logger.debug(f'Uploading {dst_file} to {remote_file_name}...')
423
+ # blob = gcs.upload(dst_file, remote_file_name, move=True)
424
+ # return blob
224
425
 
225
- blobs: list[storage.Blob]
226
- _, blobs = ThreadingQ().add_producer(producer, src_filename).add_consumer(consumer).execute()
426
+ # blobs: list[storage.Blob]
427
+ # _, blobs = ThreadingQ().add_producer(producer, src_filename).add_consumer(consumer).execute()
227
428
 
228
- # END: Upload to GCS ----->>
429
+ # # END: Upload to GCS ----->>
229
430
 
230
- # <<----- START: Load to BQ
431
+ # Upload to GCS
432
+ # TODO: Re-implement the producer-consumer model to upload multiple files
433
+ gcs = my_gcs.GCS(bucket=gcs_bucket, project_id=self.client.project)
434
+ dst_blobpath = f'tmp/my_bq/{my_datetime.get_current_datetime_str()}/{my_string.replace_nonnumeric(src_filename, "_").lower()}{src_fileextension}'
435
+ gcs.upload(src_filepath, dst_blobpath)
231
436
 
437
+ # Load to BQ
232
438
  try:
233
- gcs_filename_fqns = [f'gs://{blob.bucket.name}/{blob.name}' for blob in blobs]
234
- match load_strategy:
235
- case LoadStrategy.OVERWRITE:
236
- self.load_data_into(bq_table_fqn, gcs_filename_fqns, cols, partition_col=partition_col, cluster_cols=cluster_cols, overwrite=True)
237
- case LoadStrategy.APPEND:
238
- self.load_data_into(bq_table_fqn, gcs_filename_fqns, cols, partition_col=partition_col, cluster_cols=cluster_cols)
239
- case _:
240
- return ValueError(f'Load strategy not recognized: {load_strategy}')
241
- except Exception as e:
242
- raise e
439
+ self.load_data(
440
+ f"gs://{gcs.bucket.name}/{dst_blobpath}",
441
+ dst_table_fqn,
442
+ schema=schema,
443
+ partition_by=partition_by,
444
+ clustering_fields=clustering_fields,
445
+ format=DataFileFormat.CSV,
446
+ compression=compression,
447
+ load_strategy=load_strategy,
448
+ )
449
+ except:
450
+ raise
243
451
  finally:
244
- [GCS.remove_blob(blob) for blob in blobs]
245
-
246
- # END: Load to BQ ----->>
247
-
248
- def download_csv(self, query: str, dst_filename: str, combine: bool = True, pre_query: str = None):
249
- if not dst_filename.endswith('.csv'):
250
- raise ValueError('Destination filename must ends with .csv!')
251
-
252
- dirname = os.path.dirname(dst_filename)
253
- make_sure_path_is_directory(dirname)
254
-
255
- # Remove & recreate existing folder
256
- if os.path.exists(dirname):
257
- shutil.rmtree(dirname)
258
- os.makedirs(dirname, exist_ok=True)
259
-
260
- # Export data into GCS
261
- current_time = current_datetime_str()
262
- gcs_path = f'gs://{envs.GCS_BUCKET}/tmp/unload__{current_time}/*.csv.gz'
263
- self.export_data(query, gcs_path, pre_query)
264
-
265
- # Download into local machine
266
- gcs = GCS(self.project)
267
- logger.info('Downloads from GCS...')
268
- downloaded_filenames = []
269
- for blob in gcs.list(f'tmp/unload__{current_time}/'):
270
- file_path_part = os.path.join(dirname, blob.name.split('/')[-1])
271
- gcs.download(blob, file_path_part)
272
- downloaded_filenames.append(file_path_part)
273
-
274
- # Combine the file and clean up the file chunks
275
- if combine:
276
- logger.info('Combine downloaded csv...')
277
- csv_combine(downloaded_filenames, dst_filename)
278
- shutil.rmtree(dirname)
279
-
280
- return dst_filename
281
-
282
- def download_xlsx(self, src_table_fqn: str, dst_filename: str, xlsx_row_limit: int = 950000):
283
- if not dst_filename.endswith('.xlsx'):
284
- raise ValueError('Destination filename must ends with .xlsx!')
452
+ gcs.delete_blob(dst_blobpath)
453
+
454
+ def download_csv(
455
+ self,
456
+ query: str,
457
+ dst_filepath: str,
458
+ *,
459
+ gcs_bucket: str | None = None,
460
+ query_parameters: dict = {},
461
+ csv_row_limit: int | None = None,
462
+ ) -> str | list[str]:
463
+ if not dst_filepath.endswith(".csv"):
464
+ raise ValueError("Destination filename must ends with .csv")
465
+
466
+ # Init
467
+ gcs = my_gcs.GCS(bucket=gcs_bucket, project_id=self.client.project)
468
+
469
+ # Generic function to export-download-combine csv file from BQ->GCS->local
470
+ def _export_download_combine(
471
+ query: str,
472
+ dst_gcs_prefix: str,
473
+ dst_filepath: str,
474
+ query_parameters: dict = {},
475
+ ):
476
+ # Init tmp directory
477
+ tmp_dirname = f"/tmp/my_bq_{my_datetime.get_current_datetime_str()}"
478
+ if os.path.exists(tmp_dirname):
479
+ shutil.rmtree(tmp_dirname, ignore_errors=True)
480
+ os.makedirs(tmp_dirname, exist_ok=True)
481
+ logger.debug(f"Temporary directory created: {tmp_dirname}")
482
+
483
+ try:
484
+ # Export to GCS
485
+ dst_gcs_uri = f"gs://{gcs.bucket.name}/{dst_gcs_prefix}/*.csv.gz"
486
+ self.export_data(
487
+ query,
488
+ dst_gcs_uri,
489
+ parameters=query_parameters,
490
+ format=DataFileFormat.CSV,
491
+ compression=DataFileCompression.GZIP,
492
+ )
493
+
494
+ # Download from GCS
495
+ local_tmp_filepaths = []
496
+ for tmp_blobs in gcs.list_blobs(dst_gcs_prefix):
497
+ local_tmp_filepath = os.path.join(
498
+ tmp_dirname, tmp_blobs.name.split("/")[-1]
499
+ )
500
+ gcs.download(tmp_blobs, local_tmp_filepath, move=True)
501
+ # logger.debug(f'Downloaded {tmp_blobs.name} to {local_tmp_filepath}')
502
+ local_tmp_filepaths.append(local_tmp_filepath)
503
+
504
+ # Combine downloaded files
505
+ my_csv.combine(
506
+ local_tmp_filepaths, dst_filepath, gzip=True, delete=True
507
+ )
508
+ except:
509
+ raise
510
+ finally:
511
+ shutil.rmtree(tmp_dirname, ignore_errors=True) # Remove local folder
512
+ [
513
+ gcs.delete_blob(blob_filepath)
514
+ for blob_filepath in gcs.list_blobs(dst_gcs_prefix)
515
+ ] # Remove temporary GCS files
516
+
517
+ logger.info(f"Export-download-combine done: {dst_filepath}")
518
+
519
+ # Limited csv rows
520
+ if csv_row_limit:
521
+ tmp_table_fqn: str | None = None
522
+ tmp_table_fqn_rn: str | None = None
523
+ try:
524
+ # Create temporary table
525
+ query_job = self.execute_query(query, temporary_table=True)
526
+ tmp_table_fqn = str(query_job.destination)
527
+ logger.debug(f"Create temp table: {tmp_table_fqn}")
528
+
529
+ # Create another temporary table for row numbering
530
+ query_job = self.execute_query(
531
+ f"SELECT *, ROW_NUMBER() OVER() AS _rn FROM `{tmp_table_fqn}`",
532
+ temporary_table=True,
533
+ )
534
+ tmp_table_fqn_rn = str(query_job.destination)
535
+ logger.debug(f"Create temp table (rn): {tmp_table_fqn_rn}")
536
+
537
+ # Process parts
538
+ count = list(
539
+ self.execute_query(
540
+ f"SELECT COUNT(1) FROM `{tmp_table_fqn_rn}`"
541
+ ).result()
542
+ )[0][0]
543
+ parts = math.ceil(count / csv_row_limit)
544
+ logger.info(f"Total part: {count} / {csv_row_limit} = {parts}")
545
+ dst_filepaths = []
546
+ for part in range(parts):
547
+ dst_filepath_part = (
548
+ f'{dst_filepath.removesuffix(".csv")}_{part + 1:06}.csv'
549
+ )
550
+ _export_download_combine(
551
+ f"SELECT * EXCEPT(_rn) FROM `{tmp_table_fqn_rn}` WHERE _rn BETWEEN {(part * csv_row_limit) + 1} AND {(part + 1) * csv_row_limit} ORDER BY _rn",
552
+ dst_gcs_prefix=gcs.build_tmp_dirpath(),
553
+ dst_filepath=dst_filepath_part,
554
+ )
555
+ dst_filepaths.append(dst_filepath_part)
556
+ return dst_filepaths
557
+ except:
558
+ raise
559
+ finally:
560
+ # Drop temporary tables
561
+ if tmp_table_fqn_rn:
562
+ self.drop_table(tmp_table_fqn_rn)
563
+ if tmp_table_fqn:
564
+ self.drop_table(tmp_table_fqn)
565
+
566
+ # Unlimited csv rows
567
+ else:
568
+ _export_download_combine(
569
+ query,
570
+ gcs.build_tmp_dirpath(),
571
+ dst_filepath,
572
+ query_parameters=query_parameters,
573
+ )
574
+ return dst_filepath
575
+
576
+ # query_job_result = query_job.result()
577
+ # row_count = 0
578
+ # file_index = 1
579
+
580
+ # # Stream-download-split result
581
+ # def open_file(f):
582
+ # if f:
583
+ # f.close()
584
+ # dst_filepath_part = f'{dst_filepath.removesuffix(".csv")}_{file_index:06}.csv' if row_limit else dst_filepath
585
+ # logger.info(f'Writing into file: {dst_filepath_part} ...')
586
+ # f = open(dst_filepath_part, 'w', newline='', encoding='utf-8')
587
+ # writer = csv.writer(f)
588
+ # writer.writerow([field.name for field in query_job_result.schema]) # Write header
589
+
590
+ # return f, writer
591
+
592
+ # f, writer = open_file(None)
593
+ # for row in query_job_result:
594
+ # writer.writerow(row)
595
+
596
+ # if row_limit:
597
+ # row_count += 1
598
+ # if row_count >= row_limit:
599
+ # row_count = 0
600
+ # file_index += 1
601
+ # f, writer = open_file(f)
602
+ # if f:
603
+ # f.close()
604
+
605
def download_xlsx(
    self, src_table_fqn: str, dst_filename: str, xlsx_row_limit: int = 950000
):
    """Export a BigQuery table into one or more ``.xlsx`` files.

    Rows are split into parts of at most ``xlsx_row_limit`` rows each
    (Excel caps a worksheet at 1,048,576 rows). Each part is downloaded
    as a CSV via ``download_csv`` and then converted to
    ``<dst_filename>_part<i>.xlsx``; the intermediate CSV is removed.

    Args:
        src_table_fqn: Fully qualified source table name
            (``<projectid>.<datasetid>.<tableid>``).
        dst_filename: Destination file name; must end with ``.xlsx``.
        xlsx_row_limit: Maximum number of rows per generated xlsx file.

    Raises:
        ValueError: If ``dst_filename`` does not end with ``.xlsx``.
    """
    if not dst_filename.endswith(".xlsx"):
        raise ValueError("Destination filename must end with .xlsx!")

    # Create a temporary table with a ROW_NUMBER() column so each xlsx part
    # can be sliced out with a BETWEEN range.
    # NOTE(review): the temp name is just the source FQN + "_"; two concurrent
    # exports of the same table would collide — confirm this is acceptable.
    table_name_tmp = f"{src_table_fqn}_"
    self.execute_query(
        f"CREATE TABLE `{table_name_tmp}` AS SELECT *, ROW_NUMBER() OVER() AS _rn FROM `{src_table_fqn}`"
    )

    try:
        # Calculate the number of excel file parts based on row limit
        cnt = list(
            self.execute_query(
                f"SELECT COUNT(1) AS cnt FROM `{src_table_fqn}`"
            ).result()
        )[0][0]
        parts = math.ceil(cnt / xlsx_row_limit)
        logger.debug(f"Total part: {cnt} / {xlsx_row_limit} = {parts}")

        # Download per parts: CSV download -> xlsx conversion -> CSV cleanup.
        for part in range(parts):
            logger.debug(f"Downloading part {part + 1}...")
            file_path_tmp = f"{dst_filename}_part{part + 1}"
            file_path_tmp_csv = f"{file_path_tmp}.csv"
            self.download_csv(
                f"SELECT * EXCEPT(_rn) FROM `{table_name_tmp}` WHERE _rn BETWEEN {(part * xlsx_row_limit) + 1} AND {(part + 1) * xlsx_row_limit}",
                f"{file_path_tmp}{os.sep}",
            )
            my_xlsx.csv_to_xlsx(file_path_tmp_csv, f"{file_path_tmp}.xlsx")
            os.remove(file_path_tmp_csv)
    finally:
        # Always drop the temporary table, even when the export fails.
        # (The previous `except Exception as e: raise e` was a no-op and has
        # been removed; try/finally alone preserves the original behavior.)
        self.execute_query(f"DROP TABLE IF EXISTS `{table_name_tmp}`")
642
+
643
+ # def copy_view(self, src_view_id: str, dst_view_id: str, drop: bool = False):
644
+ # src_project_id, src_dataset_id, _ = src_view_id.split('.')
645
+ # dst_project_id, dst_dataset_id, _ = dst_view_id.split('.')
646
+
647
+ # # Create or replace
648
+ # src_view = self.client.get_table(src_view_id)
649
+ # dst_view = bigquery.Table(dst_view_id)
650
+ # dst_view.view_query = src_view.view_query.replace(f'{src_project_id}.{src_dataset_id}', f'{dst_project_id}.{dst_dataset_id}')
651
+ # self.client.delete_table(dst_view, not_found_ok=True)
652
+ # self.client.create_table(dst_view)
653
+ # logger.debug(f'View {src_view_id} copied to {dst_view}')
654
+
655
+ # if drop:
656
+ # self.client.delete_table(src_view_id)
657
+ # logger.debug(f'View {src_view_id} dropped')
658
+
659
+ # def copy_routine(self, src_routine_id: str, dst_routine_id: str, drop: bool = False):
660
+ # src_project_id, src_dataset_id, _ = src_routine_id.split('.')
661
+ # dst_project_id, dst_dataset_id, _ = dst_routine_id.split('.')
662
+
663
+ # # Create or replace
664
+ # src_routine = self.client.get_routine(src_routine_id)
665
+ # dst_routine = bigquery.Routine(dst_routine_id)
666
+ # dst_routine.body = src_routine.body.replace(f'{src_project_id}.{src_dataset_id}', f'{dst_project_id}.{dst_dataset_id}')
667
+ # dst_routine.type_ = src_routine.type_
668
+ # dst_routine.description = src_routine.description
669
+ # dst_routine.language = src_routine.language
670
+ # dst_routine.arguments = src_routine.arguments
671
+ # dst_routine.return_type = src_routine.return_type
672
+ # self.client.delete_routine(dst_routine, not_found_ok=True)
673
+ # self.client.create_routine(dst_routine)
674
+ # logger.debug(f'Routine {src_routine_id} copied to {dst_routine_id}')
675
+
676
+ # if drop:
677
+ # self.client.delete_routine(src_routine_id)
678
+ # logger.debug(f'Routine {src_routine_id} dropped')
679
+
680
+ # MARK: Utilities
681
+
682
+ @staticmethod
683
+ def get_table_fqn_parts(name: str | list[str]) -> list[str] | list[list[str]]:
684
+ """Get fully qualified table name, following this format `<projectid>.<datasetid>.<tableid>`
685
+
686
+ Args:
687
+ name (str | list[str]): Input name (can be multiple)
688
+
689
+ Returns:
690
+ list[str] | list[list[str]]: The FQN parts. If the input is list then returns list of FQN parts instead.
691
+ """
692
+
693
+ if isinstance(name, list):
694
+ return [BQ.get_table_fqn_parts(x) for x in name]
695
+
696
+ split = name.split(".")
697
+ if len(split) == 3:
698
+ return split
699
+ else:
700
+ raise ValueError(f"{name} is not a valid table FQN")
701
+
702
@staticmethod
def raise_for_invalid_table_fqn(name: str | list[str]):
    """Raise an error if the provided name is NOT a fully qualified table name.

    Validation is delegated to ``get_table_fqn_parts``, which itself raises
    ``ValueError`` for a malformed name; the ``if not`` guard below is a
    defensive fallback, since ``get_table_fqn_parts`` never returns a falsy
    value on success.

    Args:
        name (str | list[str]): Input name (can be multiple)

    Raises:
        ValueError: If name is not a fully qualified table name
    """

    if not BQ.get_table_fqn_parts(name):
        raise ValueError(f"{name} is not a valid table FQN")
715
+
716
def is_table_exists(self, table_fqn: str) -> bool:
    """Return True when ``table_fqn`` refers to an existing BigQuery table.

    Args:
        table_fqn (str): Fully qualified table name
            (``<projectid>.<datasetid>.<tableid>``).

    Returns:
        bool: True if the table exists, False if the lookup raises NotFound.

    Raises:
        ValueError: If ``table_fqn`` is not a valid FQN.
    """
    self.raise_for_invalid_table_fqn(table_fqn)
    try:
        self.client.get_table(table_fqn)
    except NotFound:
        return False
    return True
723
+
724
def close(self):
    """Close the underlying BigQuery client and release its resources."""
    self.client.close()
    logger.debug("BQ client close")