rdxz2-utill 1.0.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.

Potentially problematic release.

This version of rdxz2-utill might be problematic.

utill/my_bq.py ADDED
@@ -0,0 +1,358 @@
+ import humanize
+ import math
+ import os
+ import shutil
+
+ from enum import Enum
+ from google.cloud import bigquery, storage
+ from loguru import logger
+ from textwrap import dedent
+
+ from .my_const import ByteSize
+ from .my_csv import read_header, combine as csv_combine, compress
+ from .my_datetime import current_datetime_str
+ from .my_env import envs
+ from .my_file import make_sure_path_is_directory
+ from .my_gcs import GCS
+ from .my_queue import ThreadingQ
+ from .my_string import replace_nonnumeric
+ from .my_xlsx import csv_to_xlsx
+
+ MAP__PYTHON_DTYPE__BQ_DTYPE = {
+     int: 'INTEGER',
+     str: 'STRING',
+     float: 'STRING',
+ }
+
+
+ class LoadStrategy(Enum):
+     OVERWRITE = 1
+     APPEND = 2
+
+
+ class Dtype:
+     INT64 = 'INT64'
+     INTEGER = 'INTEGER'
+     FLOAT64 = 'FLOAT64'
+
+     DECIMAL = 'DECIMAL'
+
+     STRING = 'STRING'
+     JSON = 'JSON'
+
+     DATE = 'DATE'
+     TIME = 'TIME'
+     DATETIME = 'DATETIME'
+     TIMESTAMP = 'TIMESTAMP'
+
+     BOOL = 'BOOL'
+
+     ARRAY_INT64 = 'ARRAY<INT64>'
+     ARRAY_INTEGER = 'ARRAY<INTEGER>'
+     ARRAY_FLOAT64 = 'ARRAY<FLOAT64>'
+     ARRAY_STRING = 'ARRAY<STRING>'
+     ARRAY_JSON = 'ARRAY<JSON>'
+     ARRAY_DATE = 'ARRAY<DATE>'
+     ARRAY_DATETIME = 'ARRAY<DATETIME>'
+     ARRAY_TIMESTAMP = 'ARRAY<TIMESTAMP>'
+     ARRAY_BOOL = 'ARRAY<BOOL>'
+
+
+ class BQ():
+     def __init__(self, project: str = None):
+         self.project = project or envs.GCP_PROJECT_ID
+
+         self.client = bigquery.Client(project=self.project)
+         logger.debug(f'BQ client open, project: {self.project or "<application-default>"}')
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_value, exc_tb):
+         self.close_client()
+
+     def execute_query(self, query: str | list[str], dry_run: bool = False, parameters: dict = {}) -> bigquery.QueryJob:
+         multi = type(query) == list
+         if multi:
+             query = '\n'.join([x if str(x).strip().endswith(';') else x + ';' for x in query if x])
+         else:
+             query = query.strip()
+
+         # Build parameters
+         query_parameters = []
+         for parameter, value in parameters.items():
+             if type(value) == list:
+                 query_parameters.append(bigquery.ArrayQueryParameter(parameter, MAP__PYTHON_DTYPE__BQ_DTYPE[type(value[0])], value))
+             else:
+                 query_parameters.append(bigquery.ScalarQueryParameter(parameter, MAP__PYTHON_DTYPE__BQ_DTYPE[type(value)], value))
+
+         logger.debug(f'🔎 Query:\n{query}')
+         query_job_config = bigquery.QueryJobConfig(dry_run=dry_run, query_parameters=query_parameters)
+         query_job = self.client.query(query, job_config=query_job_config)
+         query_job.result()  # Wait for query execution
+
+         if not multi:
+             logger.debug(f'[Job ID] {query_job.job_id}, [Processed] {humanize.naturalsize(query_job.total_bytes_processed)}, [Billed] {humanize.naturalsize(query_job.total_bytes_billed)}, [Affected] {query_job.num_dml_affected_rows or 0} row(s)')
+         else:
+             logger.debug(f'[Job ID] {query_job.job_id}')
+
+             jobs: list[bigquery.QueryJob] = self.client.list_jobs(parent_job=query_job.job_id)
+             [logger.debug(f'[Script ID] {job.job_id}, [Processed] {humanize.naturalsize(job.total_bytes_processed)}, [Billed] {humanize.naturalsize(job.total_bytes_billed)}, [Affected] {job.num_dml_affected_rows or 0} row(s)') for job in jobs]
+
+         return query_job
+
+     def create_table(self, bq_table_fqn: str, schema: list[bigquery.SchemaField], partition_col: str, cluster_cols: list[str]):
+         table = bigquery.Table(bq_table_fqn, schema=schema)
+
+         if partition_col:
+             table.time_partitioning = bigquery.TimePartitioning(field=partition_col)
+             table.partitioning_type = 'DAY'
+
+         if cluster_cols:
+             table.clustering_fields = cluster_cols
+
+         bq_table = self.client.create_table(table)
+         logger.info(f'✅ Table created: {bq_table_fqn}')
+         return bq_table
+
+     def drop_table(self, bq_table_fqn: str):
+         self.client.delete_table(bq_table_fqn)
+         logger.info(f'✅ Table dropped: {bq_table_fqn}')
+
+     def load_data_into(self, bq_table_fqn: str, gcs_path: list[str] | str, cols: dict[str, Dtype], partition_col: str = None, cluster_cols: list[str] = None, overwrite: bool = False):
+         if type(gcs_path) == str:
+             gcs_path = [gcs_path]
+         gcs_path_str = ',\n'.join([f' \'{x}\'' for x in gcs_path])
+
+         load_data_keyword = 'OVERWRITE' if overwrite else 'INTO'
+         cols_str = ',\n'.join([f' `{x}` {y}' for x, y in cols.items()])
+         cluster_cols_str = ','.join([f'`{x}`' for x in cluster_cols]) if cluster_cols else None
+         query = dedent(
+             f'''
+             LOAD DATA {load_data_keyword} `{bq_table_fqn}` (
+             {cols_str}
+             )
+             {f"PARTITION BY `{partition_col}`" if partition_col is not None else "-- No partition column provided"}
+             {f"CLUSTER BY {cluster_cols_str}" if cluster_cols_str is not None else "-- No cluster column provided"}
+             FROM FILES(
+                 skip_leading_rows=1,
+                 allow_quoted_newlines=true,
+                 format='csv',
+                 compression='gzip',
+                 uris = [
+             {gcs_path_str}
+                 ]
+             );
+             '''
+         )
+
+         logger.debug(f'⌛ Load data into: {bq_table_fqn}')
+         query_job = self.execute_query(query)
+         logger.info(f'✅ Load data into: {bq_table_fqn}')
+         return query_job
+
+     def export_data(self, query: str, gcs_path: str, pre_query: str = None):
+         if '*' not in gcs_path:
+             raise ValueError('GCS path need to have a single \'*\' wildcard character')
+
+         query = dedent(
+             f'''
+             EXPORT DATA OPTIONS (
+                 uri='{gcs_path}',
+                 format='csv',
+                 compression='gzip',
+                 overwrite=true,
+                 header=true,
+                 field_delimiter=',')
+             AS (
+                 {query}
+             );
+             '''
+         )
+
+         if pre_query:
+             query = [pre_query, query]
+
+         logger.debug(f'⌛ Export data into: {gcs_path}')
+         query_job = self.execute_query(query)
+         logger.info(f'✅ Exported data into: {gcs_path}')
+         return query_job
+
+     def upload_csv(self, src_filename: str, bq_table_fqn: str, cols: dict[str, Dtype], partition_col: str = None, cluster_cols: list[str] = None, load_strategy: LoadStrategy = LoadStrategy.APPEND):
+         # <<----- START: Validation
+
+         if load_strategy not in LoadStrategy:
+             raise ValueError('Invalid load strategy')
+
+         if not src_filename.endswith('.csv'):
+             raise ValueError('Please provide file path with .csv extension!')
+
+         if partition_col is not None:
+             if partition_col not in cols.keys():
+                 raise ValueError(f'Partition \'{partition_col}\' not exists in columns!')
+         if cluster_cols is not None:
+             if not all(col in cols.keys() for col in cluster_cols):
+                 raise ValueError(f'Cluster \'{cluster_cols}\' not exists in columns!')
+
+         # Build list of columns with its datatypes
+         csv_cols = set(read_header(src_filename))
+         excessive_cols = set(cols.keys()) - set(csv_cols)
+         if excessive_cols:
+             raise ValueError(f'{len(excessive_cols)} columns not exists in CSV file: {", ".join(excessive_cols)}')
+         nonexistent_cols = set(csv_cols) - set(cols.keys())
+         if nonexistent_cols:
+             raise ValueError(f'{len(nonexistent_cols)} columns from CSV are missing: {", ".join(nonexistent_cols)}')
+
+         # END: Validation ----->>
+
+         # <<----- START: Upload to GCS
+
+         gcs = GCS(self.project)
+         tmp_dir = f'tmp/upload__{current_datetime_str()}'
+
+         # This will compress while splitting the compressed file to a certain byte size because of GCS 4GB file limitation
+         # A single file can produce more than one compressed file in GCS
+         def producer(src_file: str):
+             for dst_file in compress(src_file, keep=True, max_size_bytes=ByteSize.GB * 3):
+                 yield (dst_file, )
+
+         def consumer(dst_file: str):
+             remote_file_name = f'{tmp_dir}/{replace_nonnumeric(os.path.basename(dst_file), "_").lower()}.csv.gz'
+             logger.debug(f'Uploading {dst_file} to {remote_file_name}...')
+             blob = gcs.upload(dst_file, remote_file_name, mv=True)
+             return blob
+
+         blobs: list[storage.Blob]
+         _, blobs = ThreadingQ().add_producer(producer, src_filename).add_consumer(consumer).execute()
+
+         # END: Upload to GCS ----->>
+
+         # <<----- START: Load to BQ
+
+         try:
+             gcs_filename_fqns = [f'gs://{blob.bucket.name}/{blob.name}' for blob in blobs]
+             match load_strategy:
+                 case LoadStrategy.OVERWRITE:
+                     self.load_data_into(bq_table_fqn, gcs_filename_fqns, cols, partition_col=partition_col, cluster_cols=cluster_cols, overwrite=True)
+                 case LoadStrategy.APPEND:
+                     self.load_data_into(bq_table_fqn, gcs_filename_fqns, cols, partition_col=partition_col, cluster_cols=cluster_cols)
+                 case _:
+                     raise ValueError(f'Load strategy not recognized: {load_strategy}')
+         except Exception as e:
+             raise e
+         finally:
+             [GCS.remove_blob(blob) for blob in blobs]
+
+         # END: Load to BQ ----->>
+
+     def download_csv(self, query: str, dst_filename: str, combine: bool = True, pre_query: str = None):
+         if not dst_filename.endswith('.csv'):
+             raise ValueError('Destination filename must ends with .csv!')
+
+         dirname = os.path.dirname(dst_filename)
+         make_sure_path_is_directory(dirname)
+
+         # Remove & recreate existing folder
+         if os.path.exists(dirname):
+             shutil.rmtree(dirname)
+         os.makedirs(dirname, exist_ok=True)
+
+         # Export data into GCS
+         current_time = current_datetime_str()
+         gcs_path = f'gs://{envs.GCS_BUCKET}/tmp/unload__{current_time}/*.csv.gz'
+         self.export_data(query, gcs_path, pre_query)
+
+         # Download into local machine
+         gcs = GCS(self.project)
+         logger.info('Downloads from GCS...')
+         downloaded_filenames = []
+         for blob in gcs.list(f'tmp/unload__{current_time}/'):
+             file_path_part = os.path.join(dirname, blob.name.split('/')[-1])
+             gcs.download(blob, file_path_part)
+             downloaded_filenames.append(file_path_part)
+
+         # Combine the file and clean up the file chunks
+         if combine:
+             logger.info('Combine downloaded csv...')
+             csv_combine(downloaded_filenames, dst_filename)
+             shutil.rmtree(dirname)
+
+         return dst_filename
+
+     def download_xlsx(self, src_table_fqn: str, dst_filename: str, xlsx_row_limit: int = 950000):
+         if not dst_filename.endswith('.xlsx'):
+             raise ValueError('Destination filename must ends with .xlsx!')
+
+         # Create a temporary table acting as excel file splitting
+         table_name_tmp = f'{src_table_fqn}_'
+         self.execute_query(f'CREATE TABLE `{table_name_tmp}` AS SELECT *, ROW_NUMBER() OVER() AS _rn FROM `{src_table_fqn}`')
+
+         try:
+             # Calculate the number of excel file parts based on row limit
+             cnt = list(self.execute_query(f'SELECT COUNT(1) AS cnt FROM `{src_table_fqn}`').result())[0][0]
+             parts = math.ceil(cnt / xlsx_row_limit)
+             logger.debug(f'Total part: {cnt} / {xlsx_row_limit} = {parts}')
+
+             # Download per parts
+             for part in range(parts):
+                 logger.debug(f'Downloading part {part + 1}...')
+                 file_path_tmp = f'{dst_filename}_part{part + 1}'
+                 file_path_tmp_csv = f'{file_path_tmp}.csv'
+                 self.download_csv(f'SELECT * EXCEPT(_rn) FROM `{table_name_tmp}` WHERE _rn BETWEEN {(part * xlsx_row_limit) + 1} AND {(part + 1) * xlsx_row_limit}', f'{file_path_tmp}{os.sep}')
+                 csv_to_xlsx(file_path_tmp_csv, f'{file_path_tmp}.xlsx')
+                 os.remove(file_path_tmp_csv)
+         except Exception as e:
+             raise e
+         finally:
+             self.execute_query(f'DROP TABLE IF EXISTS `{table_name_tmp}`')
+
+     def copy_table(self, src_table_id: str, dst_table_id: str, drop: bool = False):
+         # Create or replace
+         self.client.delete_table(dst_table_id, not_found_ok=True)
+         self.client.copy_table(src_table_id, dst_table_id).result()
+         logger.debug(f'Table {src_table_id} copied to {dst_table_id}')
+
+         if drop:
+             self.client.delete_table(src_table_id)
+             logger.debug(f'Table {src_table_id} dropped')
+
+     def copy_view(self, src_view_id: str, dst_view_id: str, drop: bool = False):
+         src_project_id, src_dataset_id, _ = src_view_id.split('.')
+         dst_project_id, dst_dataset_id, _ = dst_view_id.split('.')
+
+         # Create or replace
+         src_view = self.client.get_table(src_view_id)
+         dst_view = bigquery.Table(dst_view_id)
+         dst_view.view_query = src_view.view_query.replace(f'{src_project_id}.{src_dataset_id}', f'{dst_project_id}.{dst_dataset_id}')
+         self.client.delete_table(dst_view, not_found_ok=True)
+         self.client.create_table(dst_view)
+         logger.debug(f'View {src_view_id} copied to {dst_view}')
+
+         if drop:
+             self.client.delete_table(src_view_id)
+             logger.debug(f'View {src_view_id} dropped')
+
+     def copy_routine(self, src_routine_id: str, dst_routine_id: str, drop: bool = False):
+         src_project_id, src_dataset_id, _ = src_routine_id.split('.')
+         dst_project_id, dst_dataset_id, _ = dst_routine_id.split('.')
+
+         # Create or replace
+         src_routine = self.client.get_routine(src_routine_id)
+         dst_routine = bigquery.Routine(dst_routine_id)
+         dst_routine.body = src_routine.body.replace(f'{src_project_id}.{src_dataset_id}', f'{dst_project_id}.{dst_dataset_id}')
+         dst_routine.type_ = src_routine.type_
+         dst_routine.description = src_routine.description
+         dst_routine.language = src_routine.language
+         dst_routine.arguments = src_routine.arguments
+         dst_routine.return_type = src_routine.return_type
+         self.client.delete_routine(dst_routine, not_found_ok=True)
+         self.client.create_routine(dst_routine)
+         logger.debug(f'Routine {src_routine_id} copied to {dst_routine_id}')
+
+         if drop:
+             self.client.delete_routine(src_routine_id)
+             logger.debug(f'Routine {src_routine_id} dropped')
+
+     def close_client(self):
+         self.client.close()
+         logger.debug('BQ client close')
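For orientation, a minimal usage sketch of the BQ helper above (not part of the package). The project, dataset, table, and file names are illustrative placeholders, and it assumes working GCP credentials plus the GCP_PROJECT_ID and GCS_BUCKET settings from utill.my_env.

    from utill.my_bq import BQ, Dtype, LoadStrategy

    with BQ(project='my-gcp-project') as bq:  # hypothetical project ID
        # Parameterized query; Python types map to BQ types via MAP__PYTHON_DTYPE__BQ_DTYPE
        bq.execute_query(
            'SELECT id FROM `my-gcp-project.my_dataset.events` WHERE id = @id',
            parameters={'id': 123},
        )

        # Load a local CSV through GCS into a partitioned, clustered table
        bq.upload_csv(
            'events.csv',
            'my-gcp-project.my_dataset.events',
            cols={'id': Dtype.INT64, 'created_at': Dtype.TIMESTAMP, 'payload': Dtype.STRING},
            partition_col='created_at',
            cluster_cols=['id'],
            load_strategy=LoadStrategy.APPEND,
        )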
utill/my_const.py ADDED
@@ -0,0 +1,18 @@
+ from enum import Enum
+
+
+ class ByteSize:
+     KB = 1_024
+     MB = 1_048_576
+     GB = 1_073_741_824
+     TB = 1_099_511_627_776
+
+
+ class HttpMethod(Enum):
+     GET = 1
+     POST = 2
+     PUT = 3
+     DELETE = 4
+
+     def __str__(self):
+         return self.name
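These constants are consumed elsewhere in the package (for example, ByteSize.GB * 3 is the split threshold in my_bq.upload_csv). A brief illustration:

    from utill.my_const import ByteSize, HttpMethod

    chunk_limit = ByteSize.GB * 3   # 3 GiB, the split size used when uploading CSVs
    print(str(HttpMethod.GET))      # __str__ returns the member name: 'GET'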
utill/my_csv.py ADDED
@@ -0,0 +1,90 @@
+ import csv
+ import gzip
+ import os
+ import sys
+
+ from loguru import logger
+
+ from .my_const import ByteSize
+ from .my_file import decompress
+
+
+ def read_header(filename: str):
+     filename = os.path.expanduser(filename)
+     with open(filename, 'r') as f:
+         csvreader = csv.reader(f)
+         return next(csvreader)
+
+
+ def write(filename: str, rows: list[tuple], append: bool = False):
+     filename = os.path.expanduser(filename)
+     with open(filename, 'a' if append else 'w') as f:
+         csvwriter = csv.writer(f)
+         csvwriter.writerows(rows)
+
+
+ def compress(src_filename: str, keep: bool = False, max_size_bytes=ByteSize.GB, src_fopen=None, header=None, file_count=1):
+     src_filename = os.path.expanduser(src_filename)
+     current_size = 0
+     dst_filename = f'{src_filename}_part{str(file_count).rjust(6, "0")}.gz'
+     os.remove(dst_filename) if os.path.exists(dst_filename) else None
+     logger.debug(f'📄 Compress csv {src_filename} --> {dst_filename}')
+     gz = gzip.open(dst_filename, 'wt')
+
+     src_fopen = src_fopen or open(src_filename)
+     header = header or src_fopen.readline()
+
+     gz.write(header)
+
+     while True:
+         line = src_fopen.readline()
+         if not line:
+             break
+
+         gz.write(line)
+         current_size += len(line.encode('utf-8'))
+
+         if current_size >= max_size_bytes:
+             gz.close()
+             yield dst_filename
+
+             file_count += 1
+             yield from compress(src_filename, keep, max_size_bytes, src_fopen, header, file_count)
+             return
+
+     gz.close()
+     os.remove(src_filename) if not keep else None
+     yield dst_filename
+
+
+ def combine(src_filenames: list[str], dst_filename: str) -> None:
+     csv.field_size_limit(min(sys.maxsize, 2147483646))  # FIX: _csv.Error: field larger than field limit (131072)
+
+     if not dst_filename.endswith('.csv'):
+         raise ValueError('Output filename must ends with \'.csv\'!')
+
+     first_file = True
+     with open(dst_filename, 'w') as fout:
+         csvwriter = csv.writer(fout)
+
+         for src_filename in src_filenames:
+             src_filename = os.path.expanduser(src_filename)
+
+             # Decompress gzipped csv
+             if src_filename.endswith('.csv.gz'):
+                 src_filename = decompress(src_filename)
+
+             # Copy
+             with open(src_filename, 'r') as fin:
+                 csvreader = csv.reader(fin)
+
+                 # Copy the header if this is the first file, else skip it
+                 if first_file:
+                     csvwriter.writerow(next(csvreader))
+                 else:
+                     next(csvreader)
+                 first_file = False
+
+                 [csvwriter.writerow(row) for row in csvreader]
+
+             logger.info(f'✅ Combine {src_filename}')
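A minimal sketch of the helpers above with throwaway file names; combine keeps only the first file's header, and compress is a generator that splits its gzip output at max_size_bytes.

    from utill.my_const import ByteSize
    from utill.my_csv import write, read_header, combine, compress

    # Hypothetical sample files
    write('part1.csv', [('id', 'name'), (1, 'a')])
    write('part2.csv', [('id', 'name'), (2, 'b')])
    print(read_header('part1.csv'))  # ['id', 'name']

    # Merge the parts into one CSV, keeping a single header row
    combine(['part1.csv', 'part2.csv'], 'combined.csv')

    # Split-compress into gzip parts of at most ~1 MiB each, keeping the source file
    for part in compress('combined.csv', keep=True, max_size_bytes=ByteSize.MB):
        print('wrote', part)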
utill/my_datetime.py ADDED
@@ -0,0 +1,63 @@
+ from datetime import date, datetime, timedelta
+ from enum import Enum
+
+
+ class Level(Enum):
+     DAY = 1
+     MONTH = 2
+
+
+ def get_current_date_str(use_separator: bool = False) -> str:
+     return datetime.now().strftime('%Y-%m-%d' if use_separator else '%Y%m%d')
+
+
+ def current_datetime_str(use_separator: bool = False) -> str:
+     return datetime.now().strftime('%Y-%m-%d %H:%M:%S' if use_separator else '%Y%m%d%H%M%S')
+
+
+ def get_month_first_and_last_day(string: str) -> tuple:
+     try:
+         dt = datetime.strptime(string, '%Y-%m')
+     except ValueError:
+         dt = datetime.strptime(string, '%Y-%m-%d').replace(day=1)
+
+     return (dt, (dt + timedelta(days=32)).replace(day=1) - timedelta(days=1))
+
+
+ def generate_dates(start_date: date | str, end_date: date | str, level: Level, is_output_strings: bool = False):
+     # Auto convert strings
+     if type(start_date) == str:
+         start_date = datetime.strptime(start_date, '%Y-%m-%d').date()
+     if type(end_date) == str:
+         end_date = datetime.strptime(end_date, '%Y-%m-%d').date()
+
+     # Auto convert datetime
+     if type(start_date) == datetime:
+         start_date = start_date.date()
+     if type(end_date) == datetime:
+         end_date = end_date.date()
+
+     if start_date > end_date:
+         raise ValueError(f'start_date \'{start_date}\' cannot be larger than end_date \'{end_date}\'')
+
+     dates: list[date] = []
+
+     match level:
+         case Level.DAY:
+             while end_date >= start_date:
+                 dates.append(end_date)
+                 end_date = end_date - timedelta(days=1)
+         case Level.MONTH:
+             start_date = start_date.replace(day=1)
+             end_date = end_date.replace(day=1)
+             while end_date >= start_date:
+                 end_date = end_date.replace(day=1)
+                 dates.append(end_date)
+                 end_date = end_date - timedelta(days=1)
+         case _:
+             raise ValueError(f'level \'{level}\' not recognized. available levels are: \'day\', \'month\'')
+
+     if is_output_strings:
+         return sorted([date.strftime('%Y-%m-%d') for date in dates])
+     else:
+         return sorted(dates)
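A short illustration of generate_dates and get_month_first_and_last_day with arbitrary example dates:

    from utill.my_datetime import Level, generate_dates, get_month_first_and_last_day

    # Every day in the range, ascending, as strings
    print(generate_dates('2024-01-30', '2024-02-02', Level.DAY, is_output_strings=True))
    # ['2024-01-30', '2024-01-31', '2024-02-01', '2024-02-02']

    # First day of every month touched by the range
    print(generate_dates('2024-01-15', '2024-03-15', Level.MONTH, is_output_strings=True))
    # ['2024-01-01', '2024-02-01', '2024-03-01']

    # First and last day of a month
    print(get_month_first_and_last_day('2024-02'))
    # (datetime.datetime(2024, 2, 1, 0, 0), datetime.datetime(2024, 2, 29, 0, 0))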
utill/my_dict.py ADDED
@@ -0,0 +1,12 @@
+ class AutoPopulatingDict(dict):
+     def __init__(self, fetch_function, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.fetch_function = fetch_function
+
+     def __getitem__(self, key):
+         try:
+             return super().__getitem__(key)
+         except KeyError:
+             value = self.fetch_function(key)
+             self[key] = value
+             return value
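A brief sketch of AutoPopulatingDict as a lazy, memoizing lookup; fetch_user is a stand-in for an expensive call such as an API request or database query.

    from utill.my_dict import AutoPopulatingDict

    def fetch_user(user_id):
        print(f'fetching {user_id}')   # placeholder for an expensive lookup
        return {'id': user_id}

    users = AutoPopulatingDict(fetch_user)
    users[1]   # miss: calls fetch_user(1) and caches the result
    users[1]   # hit: served from the dict, fetch_user is not called again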
utill/my_encryption.py ADDED
@@ -0,0 +1,52 @@
+ import os
+
+ from cryptography.fernet import Fernet
+ from loguru import logger
+
+
+ def __fernet_encrypt_or_decrypt(encrypt: bool, string: str, password: str):
+     return Fernet(password).encrypt(string.encode()) if encrypt else Fernet(password).decrypt(string.encode())
+
+
+ def __file_encrypt_or_decrypt(encrypt: bool, src_filename: str, password: str, dst_filename: str = None, overwrite: bool = False):
+     src_filename = os.path.expanduser(src_filename)
+
+     if not os.path.exists(src_filename):
+         raise ValueError(f'Source file not exists: {src_filename}')
+
+     with open(src_filename, 'r') as fr:
+         # If destination file is not specified, return the encrypted string
+         if not dst_filename:
+             return __fernet_encrypt_or_decrypt(encrypt, fr.read(), password)
+         # If destination file is specified, encrypt into the destination file and return the file name
+         else:
+             dst_filename = os.path.expanduser(dst_filename)
+
+             # Destination file exists checker
+             if os.path.exists(dst_filename):
+                 if not overwrite:
+                     raise ValueError(f'Destination file exists: {dst_filename}')
+                 else:
+                     os.remove(dst_filename)
+
+             with open(dst_filename, 'w') as fw:
+                 fw.write(__fernet_encrypt_or_decrypt(encrypt, fr.read(), password))
+
+             logger.info(f'Encrypted into {dst_filename}')
+             return dst_filename
+
+
+ def encrypt_file(src_filename: str, password: str, dst_filename: str = None, overwrite: bool = False) -> str:
+     return __file_encrypt_or_decrypt(True, src_filename, password, dst_filename, overwrite)
+
+
+ def decrypt_file(src_filename: str, password: str, dst_filename: str = None, overwrite: bool = False) -> str:
+     return __file_encrypt_or_decrypt(False, src_filename, password, dst_filename, overwrite)
+
+
+ def encrypt_string(string: str, password: str) -> str:
+     return __fernet_encrypt_or_decrypt(True, string, password)
+
+
+ def decrypt_string(string: str, password: str) -> str:
+     return __fernet_encrypt_or_decrypt(False, string, password)
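Note that password is used directly as a Fernet key, so it must be a 32-byte url-safe base64 key (for example from Fernet.generate_key()) rather than an arbitrary passphrase, and the helpers return bytes. A minimal sketch:

    from cryptography.fernet import Fernet
    from utill.my_encryption import encrypt_string, decrypt_string

    key = Fernet.generate_key()                  # the "password" must be a valid Fernet key
    token = encrypt_string('hello', key)         # returns a bytes token
    print(decrypt_string(token.decode(), key))   # b'hello'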
utill/my_env.py ADDED
@@ -0,0 +1,66 @@
+ import os
+ import shutil
+
+ from loguru import logger
+ from pydantic_settings import BaseSettings
+ from typing import Optional
+
+ from .my_input import ask_yes_no
+
+ ENV_DIR = os.path.expanduser(os.path.join('~', '.utill'))
+ ENV_FILE = os.path.join(ENV_DIR, 'env')
+
+ TEMPLATE_DIR = 'templates'
+ TEMPLATE_PG_FILENAME = os.path.join(os.path.dirname(__file__), TEMPLATE_DIR, 'pg.json')  # PostgreSQL connections
+ TEMPLATE_MB_FILENAME = os.path.join(os.path.dirname(__file__), TEMPLATE_DIR, 'mb.json')  # Metabase connections
+
+ PG_FILENAME = os.path.join(ENV_DIR, os.path.basename(TEMPLATE_PG_FILENAME))
+ MB_FILENAME = os.path.join(ENV_DIR, os.path.basename(TEMPLATE_MB_FILENAME))
+
+ # Make sure env dir always exists
+ if not os.path.exists(ENV_DIR):
+     os.mkdir(ENV_DIR)
+
+
+ def init_pg_file():
+     if os.path.exists(PG_FILENAME):
+         if ask_yes_no(f'PostgreSQL connection file exists: {PG_FILENAME}, overwrite?'):
+             shutil.copy(TEMPLATE_PG_FILENAME, PG_FILENAME)
+             logger.warning(f'PostgreSQL connection file overwritten! {PG_FILENAME}')
+         else:
+             return
+
+     shutil.copy(TEMPLATE_PG_FILENAME, PG_FILENAME)
+     logger.info(f'PostgreSQL connection file created: {PG_FILENAME}')
+
+
+ def init_mb_file():
+     if os.path.exists(MB_FILENAME):
+         if ask_yes_no(f'Metabase connection file exists: {MB_FILENAME}, overwrite?'):
+             shutil.copy(TEMPLATE_MB_FILENAME, MB_FILENAME)
+             logger.warning(f'Metabase connection file overwritten! {MB_FILENAME}')
+         else:
+             return
+
+     shutil.copy(TEMPLATE_MB_FILENAME, MB_FILENAME)
+     logger.info(f'Metabase connection file created: {MB_FILENAME}')
+
+
+ class Envs(BaseSettings):
+
+     GCP_PROJECT_ID: Optional[str] = None
+     GCS_BUCKET: Optional[str] = None
+
+     def set_var(self, k: str, v: str):
+         setattr(self, k, v)
+
+     def write(self):
+         with open(ENV_FILE, 'w') as f:
+             data = '\n'.join(['{}=\"{}\"'.format(k, str(getattr(self, k)).replace('\"', '\\\"')) for k in self.model_fields.keys()])
+             f.write(data)
+
+     class Config:
+         env_file = ENV_FILE
+
+
+ envs = Envs()
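For context, the envs singleton above is read from ~/.utill/env and can be updated programmatically; the values below are placeholders.

    from utill.my_env import envs, init_pg_file

    print(envs.GCP_PROJECT_ID)        # None until set in ~/.utill/env or the process environment

    envs.set_var('GCP_PROJECT_ID', 'my-gcp-project')   # hypothetical project ID
    envs.set_var('GCS_BUCKET', 'my-bucket')
    envs.write()                      # persists the settings to ~/.utill/env

    init_pg_file()                    # copies the bundled pg.json template into ~/.utill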