mb_sql 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mb/sql/__init__.py ADDED
File without changes
mb/sql/conn.py ADDED
@@ -0,0 +1,85 @@
1
+ import sqlalchemy as sa
2
+ from sqlalchemy.orm import sessionmaker,declarative_base
3
+ from mb.utils.logging import logg
4
+
5
+ __all__ = ['get_engine', 'get_session','get_base','get_metadata']
6
+
7
+
8
def get_engine(name: str = 'postgresql+psycopg2', db: str = 'postgres',
               user: str = 'postgres', password: str = 'postgres',
               host: str = 'localhost', port=5432,
               logger=None, echo=False):
    """Create a SQLAlchemy engine from connection parameters.

    Args:
        name (str): Dialect+driver string. Default: 'postgresql+psycopg2'
        db (str): Database name. Default: 'postgres'
        user (str): Username for the database. Default: 'postgres'
        password (str): Password for the database. Default: 'postgres'
        host (str): Hostname for the database. Default: 'localhost'
        port: Port for the database. Default: 5432
        logger (logging.Logger): Optional logger forwarded to ``logg``.
        echo (bool): Echo SQL statements to stdout. Default: False

    Returns:
        sqlalchemy.engine.base.Engine: the created engine.

    Raises:
        Exception: re-raised after logging if engine creation fails.
    """
    url = f'{name}://{user}:{password}@{host}:{port}/{db}'
    try:
        created = sa.create_engine(url, echo=echo)
        logg.info(f'Engine created for {name} database.', logger=logger)
        return created
    except Exception as exc:
        logg.error(f'Error creating engine for {name} database.', logger=logger)
        raise exc
34
+
35
def get_session(engine, logger=None):
    """Create a SQLAlchemy ORM session bound to the given engine.

    Args:
        engine (sqlalchemy.engine.base.Engine): Engine object.
        logger (logging.Logger): Optional logger forwarded to ``logg``.

    Returns:
        sqlalchemy.orm.session.Session: a new session.

    Raises:
        Exception: re-raised after logging if session creation fails.
    """
    try:
        factory = sessionmaker(bind=engine)
        new_session = factory()
        logg.info(f'Session created for {engine.url.database} database.', logger=logger)
        return new_session
    except Exception as exc:
        logg.error(f'Error creating session for {engine.url.database} database.', logger=logger)
        raise exc
52
+
53
def get_base(logger=None):
    """Create a SQLAlchemy declarative base class.

    Args:
        logger (logging.Logger): Optional logger forwarded to ``logg``.

    Returns:
        DeclarativeMeta: a fresh declarative base class.

    Raises:
        Exception: re-raised after logging on failure. (Previously the
        exception was swallowed and the function implicitly returned None,
        hiding the error from callers.)
    """
    try:
        Base = declarative_base()
        logg.info('Base created for database.', logger=logger)
        return Base
    except Exception:
        logg.error('Error creating base for database.', logger=logger)
        raise
68
+
69
def get_metadata(base, conn, logger=None):
    """Create all tables registered on a declarative base and return its metadata.

    Args:
        base (DeclarativeMeta): Declarative base whose tables are created.
        conn (sqlalchemy.engine.base.Connection): Connection/engine to run DDL on.
        logger (logging.Logger): Optional logger forwarded to ``logg``.

    Returns:
        sqlalchemy.sql.schema.MetaData: the base's metadata object.
        (``MetaData.create_all`` returns None, so the previous code
        ``return base.metadata.create_all(conn)`` always returned None.)

    Raises:
        Exception: re-raised after logging on failure (previously swallowed,
        implicitly returning None).
    """
    try:
        base.metadata.create_all(conn)
        logg.info('Metadata created for database.', logger=logger)
        return base.metadata
    except Exception:
        logg.error('Error creating metadata for database.', logger=logger)
        raise
85
+
mb/sql/slack.py ADDED
@@ -0,0 +1,24 @@
1
+
2
+ import requests
3
+ import json
4
+ from mb.utils.logging import logg
5
+
6
+ __all__ = ['slack_msg']
7
+
8
def slack_msg(webhook,msg,logger=None):
    """
    Send a message to a slack channel
    Args:
        webhook (str): Slack webhook URL
        msg (str or dict): Message text, or a complete Slack payload dict
        logger (logging.Logger): Logger to use
    Returns:
        None
    """
    # Slack incoming webhooks require a JSON *object* payload such as
    # {"text": "..."}. The previous code json.dumps()'d a bare string,
    # which Slack rejects with "invalid_payload". Wrap plain strings;
    # pass dicts through unchanged for callers that build full payloads.
    payload = {'text': msg} if isinstance(msg, str) else msg

    response = requests.post(
        url=webhook,
        data=json.dumps(payload),
        headers={'Content-Type': 'application/json'})

    logg.info('Slack response: %s', response.text, logger=logger)
mb/sql/snapshot.py ADDED
@@ -0,0 +1,197 @@
1
+ '''
2
+ fix this file.
3
+ It should be able to take any snapshot of the current data and save it to S3, and vice versa.
4
+ An ID should be given and stored in S3/DB as well, along with the time data.
5
+ '''
6
+ from mb.utils import s3
7
+ from typing import Optional
8
+ import pandas as pd
9
+ from mb.utils.logging import logg
10
+ from mb.pandas.dfload import load_any_df
11
+
12
+ __api__ = [
13
+ "get_s3url",
14
+ "list_names",
15
+ "list_ids",
16
+ "latest_id",
17
+ "make_new_id",
18
+ "save",
19
+ "load",
20
+ ]
21
+
22
+
23
def get_s3url(df_name,id,location='s3://snapshots/dataframes/',file_type='csv'):
    """Gets the s3url of a dataframe snapshot.

    Parameters
    ----------
    df_name : str
        dataframe name
    id : int
        snapshot id
    location : str
        s3cmd url to the snapshots folder (with or without a trailing '/')
    file_type : str
        file type of the snapshot

    Returns
    -------
    s3url : str
        s3cmd url to the snapshot file
    """
    # Strip any trailing '/' from location: the default value ends in '/'
    # and the format below adds another, which used to produce a key with
    # a double slash ('.../dataframes//<name>/...').
    return f"{location.rstrip('/')}/{df_name}/{id}.{file_type}"
44
+
45
+
46
def list_names(location='s3://snapshots/dataframes/'):
    """Return the names of the dataframes that have snapshots.

    Parameters
    ----------
    location : str
        s3cmd url to the snapshots folder

    Returns
    -------
    list
        list of dataframe names available in S3
    """
    return s3.list_objects(f"{location}")
63
+
64
def list_ids(df_name: str, show_progress: bool = False) -> list:
    """Return the snapshot ids available for a given dataframe.

    Parameters
    ----------
    df_name : str
        dataframe name
    show_progress : bool
        show a progress spinner in the terminal

    Returns
    -------
    list
        list of snapshot ids available for the dataframe
    """
    # NOTE(review): this lists a hard-coded bucket, which differs from the
    # default `location` used by get_s3url() -- confirm which is canonical.
    prefix = f"s3://vision-ml-datasets/snapshots/dataframes/{df_name}/"
    object_urls = s3.list_objects(prefix, show_progress=show_progress)
    ids = []
    for url in object_urls:
        filename = url[len(prefix):]
        # snapshot files are named 'sn<ID>.parquet'
        if filename.startswith("sn") and filename.endswith(".parquet"):
            ids.append(int(filename[2:-8]))
    return ids
89
+
90
+
91
def latest_id(df_name: str) -> int:
    """Return the most recent snapshot id for a given dataframe.

    Parameters
    ----------
    df_name : str
        dataframe name

    Returns
    -------
    int
        the largest snapshot id available for the dataframe

    Raises
    ------
    IndexError
        if no snapshot is available
    """
    ids = list_ids(df_name, show_progress=False)
    if not ids:
        raise IndexError(f"No snapshot available for dataframe '{df_name}'.")
    return max(ids)
114
+
115
+
116
def make_new_id(df_name: str) -> int:
    """Make a new snapshot id (YYMMDD as an integer) from the current UTC date.

    Parameters
    ----------
    df_name : str
        dataframe name (unused; kept for interface symmetry with the
        other snapshot helpers)

    Returns
    -------
    int
        a snapshot id encoding today's UTC date as YYMMDD
    """
    now = pd.Timestamp.utcnow()
    return ((now.year % 100) * 100 + now.month) * 100 + now.day
131
+
132
+
133
def save(
    df: pd.DataFrame,
    df_name: str,
    sn_id: Optional[int] = None,
    local_path: Optional[str] = './snapshot.parquet',
    logger = None,
):
    """Snapshots a dataframe.

    Parameters
    ----------
    df : pandas.DataFrame
        the dataframe to be saved
    df_name : str
        dataframe name
    sn_id : int, optional
        snapshot id to which the dataframe will be saved. If not provided, the id will be
        automatically generated.
    local_path : str, optional
        local path to save the snapshot. Default is './snapshot.parquet'
    logger : optional
        logger to log the snapshotting process. If not provided, no logging will be done.

    Returns
    -------
    pandas.DataFrame
        the dataframe that was saved
    """

    if sn_id is None:
        sn_id = make_new_id(df_name)

    s3url = get_s3url(df_name, sn_id)
    logg.info(f"Snapshotting dataframe {df_name} with id {sn_id}", logger=logger)

    # BUG FIX: the previous implementation *downloaded* the snapshot from
    # s3url via load_any_df() and wrote that to local_path, silently
    # discarding the `df` argument it was documented to save.
    df.to_parquet(local_path, index=False)
    # TODO(review): upload `local_path` to `s3url`; the s3 helper's upload
    # API is not visible from this module -- confirm and wire it up.
    return df
166
+
167
+
168
def load(
    df_name: str,
    sn_id: Optional[int] = None,
    logger: Optional = None,
) -> pd.DataFrame:
    """Loads a dataframe snapshot.

    NOTE(review): this function is incomplete -- the download/return steps
    are commented out below, so it currently resolves the snapshot URL and
    then implicitly returns None, not a DataFrame as annotated.

    Parameters
    ----------
    df_name : str
        dataframe name
    sn_id : int, optional
        snapshot id to load. If not provided, the id corresponding to
        the latest snapshot will be used.
    logger : optional
        logger for progress reporting

    Returns
    -------
    pandas.DataFrame
        the snapshotted dataframe (currently None -- see note above)
    """

    if sn_id is None:
        sn_id = latest_id(df_name)

    msg = f"Loading a dataframe {df_name} with snapshot id {sn_id}"
    with logg.scoped_info(msg, logger=logger):
        s3url = get_s3url(df_name, sn_id)
        # filepath = cache(s3url, logger=logger)

        # return pd.dfload(filepath, show_progress=bool(logger), unpack=False)
mb/sql/sql.py ADDED
@@ -0,0 +1,97 @@
1
+ import sqlalchemy as sa
2
+ import pandas as pd
3
+ from mb.utils.logging import logg
4
+
5
+ __all__ = ['read_sql','engine_execute','to_sql']
6
+
7
def trim_sql_query(sql_query: str) -> str:
    """
    Collapse every run of whitespace (including newlines) in a SQL query
    into a single space.
    """
    # str.split() with no argument splits on any whitespace, which covers
    # the newline handling the original did in a separate pass.
    return " ".join(sql_query.split())
14
+
15
def read_sql(query,engine,index_col=None,chunk_size=10000,logger=None):
    """Read SQL query into a DataFrame.

    Args:
        query (str or sqlalchemy Select): SQL query; strings are trimmed
            and wrapped in sa.text().
        engine (sqlalchemy.engine.base.Engine): Engine object.
        index_col (str): Column to use as the index. Default: None
        chunk_size (int): Rows fetched per chunk; None or 0 reads in one
            shot. Default: 10000
        logger (logging.Logger): Logger object.
    Returns:
        df (pandas.core.frame.DataFrame): DataFrame object.
    Raises:
        sqlalchemy.exc.SQLAlchemyError: re-raised after logging.
    """
    try:
        if isinstance(query, str):
            query = sa.text(trim_sql_query(query))

        # `not chunk_size` covers both None and 0 (was `chunk_size==None`,
        # which should be an identity check anyway).
        if not chunk_size:
            with engine.begin() as conn:
                df = pd.read_sql(query, conn, index_col=index_col)
        else:
            with engine.begin() as conn:
                # Collect chunks and concatenate once: the old per-chunk
                # pd.concat([df, chunk]) pattern is quadratic.
                chunks = list(pd.read_sql(query, conn, index_col=index_col,
                                          chunksize=chunk_size))
            df = (pd.concat(chunks, ignore_index=True)
                  if chunks else pd.DataFrame())

        # Previously only the chunked path logged success; log both.
        logg.info('SQL query executed successfully.',logger=logger)
        return df
    except sa.exc.SQLAlchemyError as e:
        logg.error('Error executing SQL query.',logger=logger)
        raise e
48
+
49
def engine_execute(engine, query_str):
    """
    Execute a query on a SQLAlchemy engine object.

    Args:
        engine (sqlalchemy.engine.base.Engine): Engine object.
        query_str (str): Query string.
    Returns:
        result (sqlalchemy.engine.result.ResultProxy): Result object.
    """
    stmt = sa.text(query_str) if isinstance(query_str, str) else query_str

    if isinstance(engine, sa.engine.Engine):
        # Engines need a transaction scope; begin() commits on success.
        with engine.begin() as conn:
            return conn.execute(stmt)
    if isinstance(engine, sa.engine.Connection):
        # Already-open connections are executed on directly.
        return engine.execute(stmt)
69
+
70
+
71
def to_sql(df,engine,table_name,schema=None,if_exists='replace',index=False,index_label=None,chunksize=10000,logger=None):
    """Write records stored in a DataFrame to a SQL database.

    Args:
        df (pandas.core.frame.DataFrame): DataFrame object.
        engine (sqlalchemy.engine.base.Engine): Engine object.
        table_name (str): Name of the table.
        schema (str): Name of the schema. Default: None
        if_exists (str): How to behave if the table already exists. Default: 'replace'
        index (bool): Write DataFrame index as a column. Default: False
        index_label (str): Column label for index column(s); falls back to
            the DataFrame's index name when index=True and no label is given.
        chunksize (int): Number of rows to write at a time. Default: 10000
        logger (logging.Logger): Logger object.
    Returns:
        None
    Raises:
        Exception: re-raised after logging on failure.
    """
    try:
        # Default the index label from the frame only when the index is
        # actually being written.
        if index and index_label is None:
            index_label = df.index.name
        df.to_sql(table_name, engine, schema=schema, if_exists=if_exists,
                  index=index, index_label=index_label, chunksize=chunksize)
        logg.info(f'DataFrame written to {table_name} table.',logger=logger)
    except Exception as exc:
        logg.error(f'Error writing DataFrame to {table_name} table.',logger=logger)
        raise exc
96
+
97
+
mb/sql/tables.py ADDED
@@ -0,0 +1,64 @@
1
+ ## Tables to be updated every night by the cron job
2
+
3
+ import sqlalchemy as sa
4
+ import typing as tp
5
+ from .conn import get_engine
6
+
7
+
8
class TableConfig:
    """
    Table configuration object.

    Describes one table synced by the nightly cron job: its source
    schema/table, the index and updated-timestamp columns used for
    incremental sync, the chunk size, and the source/destination engines.
    """
    def __init__(
        self,
        schema: str,
        table: str,
        index_col: str,
        chunk_size: int,
        updated_col: str,
        dst_engine: str = "mb_public2",
        dtype: tp.Optional[dict] = None,
    ):
        self.schema = schema            # source schema name, e.g. "mb_public1"
        self.table = table              # table name within the schema
        self.index_col = index_col      # primary-key column (may be None)
        self.chunk_size = chunk_size    # rows per sync chunk
        self.updated_col = updated_col  # timestamp column for incremental sync (may be None)
        self.dst_engine = dst_engine    # destination engine *name* until get_dst_engine() runs
        self.dtype = dtype              # optional column-name -> sqlalchemy type mapping

    def get_src_engine(self):
        # NOTE(review): if schema != "mb_public1" this returns (and reads)
        # self.src_engine without it ever being assigned -> AttributeError.
        if self.schema == "mb_public1":
            self.src_engine = get_engine(name='postgresql' , db= 'postgres', user='postgres' , password= 'postgres', host= 'localhost', port= 5432, echo=False)
        return self.src_engine

    def get_dst_engine(self):
        # NOTE(review): this checks schema == "mb_public2", but the tables
        # configured below all use schema "mb_public1", so the branch never
        # fires and the *string* stored in __init__ ("mb_public2") is
        # returned instead of an Engine -- confirm the intended condition.
        if self.schema == "mb_public2":
            self.dst_engine =get_engine(name='postgresql' , db= 'postgres_2', user='postgres' , password= 'postgres', host= 'localhost', port= 5432, echo=False)
        return self.dst_engine
39
+
40
+
41
+
42
# Registry of tables to be synced nightly by the cron job.
# Positional TableConfig args: schema, table, index_col, chunk_size, updated_col.
mutable_tables = {
    # "test2" has no index or updated column: synced without incremental keys.
    'table_to_update1': TableConfig('mb_public1',
                                    "test2",
                                    None,
                                    10000,
                                    None,
                                    dtype={
                                        'name': sa.Text,
                                        'age' : sa.Integer
                                    },
                                    ),

    # "test3" is keyed on 'id' but has no updated column (id-based sync only).
    'table_to_update2': TableConfig('mb_public1',
                                    "test3",
                                    'id',
                                    10000,
                                    None,
                                    dtype={
                                        'id':sa.Integer,
                                        'num': sa.Integer,
                                        'data': sa.Text
                                    },)
}
mb/sql/update.py ADDED
@@ -0,0 +1,709 @@
1
+ import sqlalchemy as sa
2
+ import pandas as pd
3
+ import sqlalchemy.dialects.postgresql
4
+ from .sql import read_sql
5
+ from .utils import list_tables
6
+ import boto3
7
+
8
+ __all__ = ['get_last_updated_timestamp']
9
+
10
+
11
def get_last_updated_timestamp(engine,table_name: str,schema : str,updated_col: str = "updated") -> pd.Timestamp:
    """
    Gets the last updated timestamp over all records of a table containing the 'updated' field.

    Args:
        engine (sqlalchemy.engine.base.Engine): Engine object.
        table_name (str): table containing the 'updated' field.
        schema (str): Name of the schema (currently unused by the query).
        updated_col (str): name of the 'updated' field. Default: 'updated'
    Returns:
        ts (pandas.Timestamp): last updated timestamp; 2014-01-01 when the
        table has no non-null timestamps.
    """

    result = read_sql(f"SELECT MAX({updated_col}) AS last_updated FROM {table_name};", engine)

    ts = pd.Timestamp(result["last_updated"][0])
    # An empty/NULL result falls back to a fixed epoch for this dataset.
    return pd.Timestamp("2014-01-01") if pd.isnull(ts) else ts
31
+
32
def get_last_updated_data_id(engine,table_name: str,schema : str,updated_col: str = "updated") -> pd.DataFrame:
    """
    Gets the id and updated timestamp for all records of a table containing the 'updated' field.

    Args:
        engine (sqlalchemy.engine.base.Engine): Engine object.
        table_name (str): table containing the 'updated' field.
        schema (str): Name of the schema (currently unused by the query).
        updated_col (str): name of the 'updated' field. Default: 'updated'
    Returns:
        df (pandas.DataFrame): frame with 'id' and the updated column; null
        timestamps are replaced with 2014-01-01. (The old return annotation
        of pd.Timestamp was wrong -- a DataFrame was always returned.)
    """
    query_str = f"SELECT id, {updated_col} FROM {table_name};"
    df = read_sql(query_str, engine)

    # BUG FIX: the previous loop assigned into the row *copies* yielded by
    # iterrows(), which never writes back to the DataFrame. fillna does the
    # intended replacement in one vectorized pass.
    df[updated_col] = df[updated_col].fillna(pd.Timestamp("2014-01-01"))
    return df
51
+
52
+
53
def get_new_updated_timestamps_and_id_from_s3(bucket_name,table_name, key, last_updated,
                                access_key='your_access_key',
                                secret_key='your_secret_key',region_name='your_region_name'):
    """Read a CSV from S3 and return the rows newer than `last_updated` whose
    ids are not already present locally.

    NOTE(review): several problems here to confirm/fix:
      - get_last_updated_data_id() requires (engine, table_name, schema) but
        is called with only table_name -- this raises TypeError as written.
      - it returns a DataFrame; iterating it yields *column names*, so
        `r[0]` would be the first character of a column name, not an id.
      - the membership list is rebuilt on every loop iteration (O(n^2)).
      - `new_data.iloc[-1]` raises IndexError when no new rows exist.
      - assumes the CSV has 'last_updated' and 'id' columns -- verify.
    """
    s3 = boto3.client('s3', aws_access_key_id=access_key,
                      aws_secret_access_key=secret_key, region_name=region_name)
    obj = s3.get_object(Bucket=bucket_name, Key=key)
    data = pd.read_csv(obj['Body'])
    new_data = data[data['last_updated'] > last_updated]
    new_rows = []
    for index, row in new_data.iterrows():
        if row['id'] not in [r[0] for r in get_last_updated_data_id(table_name)]:
            new_rows.append(row)
    return new_rows, new_data.iloc[-1]['last_updated']
66
+
67
+ # Get new updated timestamps from S3
68
def get_new_updated_timestamps_from_s3(bucket_name, key, last_updated,
                            access_key='your_access_key',
                            secret_key='your_secret_key',region_name='your_region_name'):
    """
    Get new updated timestamps from S3 bucket

    Args:
        bucket_name (str): name of the S3 bucket
        key (str): name of the file in the S3 bucket
        last_updated (pandas.Timestamp): last updated timestamp
        access_key (str): AWS access key id
        secret_key (str): AWS secret access key
        region_name (str): AWS region name
    Returns:
        new_data (pandas.DataFrame): rows whose 'last_updated' column is
        strictly newer than the given timestamp.
    """
    client = boto3.client('s3', aws_access_key_id=access_key,
                          aws_secret_access_key=secret_key, region_name=region_name)
    body = client.get_object(Bucket=bucket_name, Key=key)['Body']
    data = pd.read_csv(body)
    return data[data['last_updated'] > last_updated]
87
+
88
+ #####
89
+ '''
90
+
91
+ # Update local table
92
+ def update(table_name, data):
93
+ for index, row in data.iterrows():
94
+ query = f"UPDATE {table_name} SET column1 = {row['column1']}, column2 = {row['column2']} WHERE id = {row['id']}"
95
+ cur.execute(query)
96
+
97
+ # Read sync via id
98
+ def readsync_via_id(table_name, id):
99
+ query = f"SELECT * from {table_name} WHERE id = {id}"
100
+ cur.execute(query)
101
+ rows = cur.fetchall()
102
+ return rows
103
+
104
+ # Merge data
105
+ def merge(table_name, data):
106
+ conn = create_conn()
107
+ cur = conn.cursor()
108
+ for index, row in data.iterrows():
109
+ query = f"INSERT INTO {table_name} (id, column1, column2) VALUES ({row['id']}, {row['column1']}, {row['column2']}) ON CONFLICT (id) DO UPDATE SET column1 = {row['column1']}, column2 = {row['column2']}"
110
+ cur.execute(query)
111
+ conn.commit()
112
+ conn.close()
113
+
114
+ # Get duplicates
115
+ def get_duplicates(table_name):
116
+ query = f"SELECT id, COUNT(*) FROM {table_name} GROUP BY id HAVING COUNT(*) > 1"
117
+ cur.execute(query)
118
+ rows = cur.fetchall()
119
+ return rows
120
+
121
+ # Get updated fields
122
+ def get_updated_fields(table_name, id):
123
+ query = f"SELECT column1, column2 from {table_name} WHERE id = {id}"
124
+ cur.execute(query)
125
+ rows = cur.fetchall()
126
+ return rows
127
+
128
+
129
+ ######
130
+
131
+
132
+ def get_new_updated_timestamps(mutable_name: str, mutable_table: dict, last_uts: pd.Timestamp, limit: int = 10000) -> pd.Timestamp:
133
+ """
134
+ Obtains all new timestamps and their counts from the 'updated' field of the table.
135
+
136
+ Args:
137
+ mutable_name (str) : name of the table
138
+ table (dict) : dict containing the 'mutable_tables'.
139
+ last_uts (pandas.Timestamp) : the last updated timestamp of the corresponding table
140
+ limit (int) : maximum number of timestamps to be queried
141
+
142
+ Returns:
143
+ df (pd.DataFrame) : dataframe containing 'updated', 'count' columns. The 'updated' field is sorted in ascending order
144
+ """
145
+
146
+ if mutable_name not in mutable_table:
147
+ raise ValueError("Unknown mutable with name '{}'.".format(mutable_name))
148
+
149
+ t = mutable_table[mutable_name]
150
+ remote_engine = t.get_src_engine()
151
+ if t.scheme:
152
+ final_table = f"{t.scheme}.{t.table}"
153
+ else:
154
+ final_table = t.table
155
+
156
+ query_str = "SELECT {updated_col}, COUNT(*) AS count FROM {} WHERE {updated_col} > '{}' GROUP BY {updated_col} ORDER BY {updated_col} LIMIT {};".format(
157
+ final_table, str(last_uts), limit, updated_col=t.updated_col
158
+ )
159
+ df = read_sql(query_str, remote_engine)
160
+
161
+ return df
162
+
163
+
164
+ def _to_sql(df, table_name, engine, dtype=None, from_mysql: bool = False, **kwargs):
165
+ if from_mysql and isinstance(dtype, dict):
166
+ for k, v in dtype.items():
167
+ if v != sa.Text:
168
+ continue
169
+ df[k] = df[k].apply(
170
+ lambda x: x.decode() if isinstance(x, (bytes, bytearray)) else x
171
+ )
172
+ return df.to_sql(table_name, engine, dtype=dtype, **kwargs)
173
+
174
+
175
+ def update(df: pd.DataFrame, mutable_name: str, mutable_table: dict, index_col: str, is_new_table: bool) -> int:
176
+ if mutable_name not in mutable_table:
177
+ raise ValueError("Unknown mutable with name '{}'.".format(mutable_name))
178
+
179
+ t = mutable_table[mutable_name]
180
+ engine = t.get_dst_engine()
181
+
182
+ with engine.begin() as conn: #
183
+ if not is_new_table:
184
+ query_str = ",".join((str(x) for x in df[index_col].tolist()))
185
+ query_str = "DELETE FROM {} WHERE {} IN ({});".format(
186
+ mutable_name, index_col, query_str
187
+ )
188
+ conn.execute(sa.text(query_str))
189
+
190
+ df = df.set_index(index_col, drop=True)
191
+ _to_sql(
192
+ df,
193
+ mutable_name,
194
+ conn,
195
+ if_exists="append",
196
+ dtype=t.dtype
197
+ )
198
+
199
+ return len(df)
200
+
201
+
202
+ def readsync_via_id(mutable_name: str,mutable_table: dict,logger=None):
203
+ """Read-sync a muv1db table containing the 'updated' field.
204
+
205
+ Parameters
206
+ ----------
207
+ mutable_name : str
208
+ one of the muv1db tables containing the 'updated' field. Must be a key of `mutable_map`
209
+ module variable.
210
+ logger : mt.logg.IndentedLoggerAdapter, optional
211
+ logger for debugging purposes
212
+ """
213
+
214
+ if mutable_name not in mutable_table:
215
+ raise ValueError("Unknown mutable with name '{}'.".format(mutable_name))
216
+
217
+ t = mutable_table[mutable_name]
218
+ engine = t.get_dst_engine()
219
+ schema = t.schema
220
+ table = t.table
221
+ index_col = t.index_col
222
+ chunk_size = t.chunk_size
223
+
224
+ remote_engine = t.get_src_engine()
225
+ if schema:
226
+ final_table = f"{schema}.{table}"
227
+ else:
228
+ final_table = table
229
+
230
+
231
+ if mutable_name in list_tables(engine):
232
+ is_new_table = False
233
+ query_str = "SELECT MAX({index_col}) AS last_id FROM {mutable_name}".format(
234
+ index_col=index_col, mutable_name=mutable_name
235
+ )
236
+ last_id = read_sql(query_str, engine)["last_id"][0]
237
+ else:
238
+ is_new_table = True
239
+ last_id = 0
240
+
241
+ msg = "Read-syncing mutable '{}' with last id {}".format(mutable_name, last_id)
242
+ with logg.scoped_info(msg, logger=logger):
243
+ first_id = last_id
244
+ query_str = "SELECT MAX({index_col}) AS last_id FROM {final_table}".format(
245
+ index_col=index_col, frame_str=final_table
246
+ )
247
+ last_id = read_sql(query_str, remote_engine)["last_id"][0]
248
+ if logger:
249
+ logger.debug("Remote has last id {}.".format(last_id))
250
+
251
+ query_str = "SELECT COUNT(*) AS cnt FROM {frame_str} WHERE {index_col} > {first_id} AND {index_col} <= {last_id};".format(
252
+ index_col=index_col, frame_str=final_table, first_id=first_id, last_id=last_id
253
+ )
254
+ count = read_sql(query_str, remote_engine)["cnt"][0]
255
+ if count == 0:
256
+ return
257
+
258
+ # determine new fold
259
+ fold = (count // chunk_size) + 1
260
+ if logger:
261
+ logger.info(
262
+ " Found {} records to be downloaded in {} chunks.".format(count, fold)
263
+ )
264
+
265
+ for i in range(fold):
266
+ start_ofs = first_id + (last_id - first_id) * i // fold
267
+ end_ofs = first_id + (last_id - first_id) * (i + 1) // fold
268
+
269
+ query_str = "SELECT * FROM {frame_str} WHERE {index_col}>{first_id} and {index_col}<={last_id};".format(
270
+ frame_str=frame_str,
271
+ index_col=index_col,
272
+ first_id=start_ofs,
273
+ last_id=end_ofs,
274
+ )
275
+ df = ss.read_sql(query_str, remote_engine)
276
+
277
+ if len(df) == 0:
278
+ continue
279
+
280
+ if logger:
281
+ logger.info(
282
+ " {}: Downloaded {} records with id in [{},{}].".format(
283
+ i + 1, len(df), start_ofs, end_ofs
284
+ )
285
+ )
286
+
287
+ update(df, mutable_name, index_col, is_new_table)
288
+
289
+
290
+ def readsync(mutable_name: str, logger=None):
291
+ """Read-sync a muv1db table containing the 'updated' field.
292
+
293
+ Parameters
294
+ ----------
295
+ mutable_name : str
296
+ one of the muv1db tables containing the 'updated' field. Must be a key of `mutable_map`
297
+ module variable.
298
+ logger : mt.logg.IndentedLoggerAdapter, optional
299
+ logger for debugging purposes
300
+ """
301
+
302
+ t = mutable_map[mutable_name]
303
+
304
+ if t.updated_col is None:
305
+ return readsync_via_id(mutable_name, logger=logger)
306
+
307
+ schema = t.schema
308
+ table = t.table
309
+ index_col = t.index_col
310
+ chunk_size = t.chunk_size
311
+ updated_col = t.updated_col
312
+
313
+ remote_engine = t.get_src_engine()
314
+ frame_str = ss.frame_sql(table, schema=schema)
315
+
316
+ last_uts = get_last_updated_timestamp(mutable_name, updated_col=updated_col)
317
+ is_new_table = last_uts < pd.Timestamp("2017-01-02")
318
+
319
+ msg = "Read-syncing mutable '{}' with timestamp '{}'".format(
320
+ mutable_name, str(last_uts)
321
+ )
322
+ with logg.scoped_info(msg, logger=logger):
323
+ count = 0
324
+ while True:
325
+ thresh_uts = pd.Timestamp.utcnow().tz_localize(None) - pd.Timedelta(
326
+ 1, "hours"
327
+ )
328
+ if last_uts > thresh_uts:
329
+ break
330
+
331
+ # large enough so that the probing+sorting op is neglible compared to the downloading op
332
+ fold = 16
333
+
334
+ if logger:
335
+ logger.info(
336
+ "Inspecting maximum {} timestamps after '{}':".format(
337
+ chunk_size * fold, str(last_uts)
338
+ )
339
+ )
340
+ df0 = get_new_updated_timestamps(
341
+ mutable_name, last_uts, limit=chunk_size * fold
342
+ )
343
+ if len(df0) == 0:
344
+ break
345
+
346
+ # determine new fold
347
+ count_sum = df0["count"].sum()
348
+ fold = (count_sum // chunk_size) + 1
349
+ if logger:
350
+ logger.info(
351
+ " Found {} records to be downloaded in {} chunks.".format(
352
+ count_sum, fold
353
+ )
354
+ )
355
+
356
+ df0_cnt = len(df0)
357
+ for i in range(fold):
358
+ start_ofs = df0_cnt * i // fold
359
+ end_ofs = (df0_cnt * (i + 1)) // fold
360
+
361
+ df = df0.iloc[start_ofs:end_ofs]
362
+
363
+ first_uts = last_uts
364
+ last_uts = df[updated_col].max()
365
+ chunk_cnt = df["count"].sum()
366
+
367
+ if logger:
368
+ logger.info(
369
+ " {}: Downloading {} records timestamped in ['{}','{}'].".format(
370
+ i + 1, chunk_cnt, str(first_uts), str(last_uts)
371
+ )
372
+ )
373
+ query_str = "SELECT * FROM {frame_str} WHERE {updated_col}>'{first_uts}' and {updated_col}<='{last_uts}';".format(
374
+ frame_str=frame_str,
375
+ updated_col=updated_col,
376
+ first_uts=str(first_uts),
377
+ last_uts=str(last_uts),
378
+ )
379
+ df = ss.read_sql(query_str, remote_engine)
380
+
381
+ count += update(df, mutable_name, index_col, is_new_table)
382
+
383
+ if logger:
384
+ logger.info(
385
+ "Downloaded {} records and updated timestamp to '{}'.".format(
386
+ count, str(last_uts)
387
+ )
388
+ )
389
+
390
+
391
+ def merge_from(mutable_name: str, other_engine, logger=None):
392
+ """Merges a mutable from another sqlite database to this database.
393
+
394
+ Parameters
395
+ ----------
396
+ mutable_name : str
397
+ mutable to merge. Must be a key of `mutable_map` module variable.
398
+ other_engine : sqlalchemy.engine.Engine
399
+ connection engine to an sqlite3 wml database that contains the mutable to merge from
400
+ logger : mt.logg.IndentedLoggerAdapter, optional
401
+ logger for debugging purposes
402
+
403
+ Returns
404
+ -------
405
+ bool
406
+ whether the table has been successfully merged
407
+ """
408
+
409
+ if not mutable_name in mutable_map:
410
+ raise ValueError(
411
+ "First argument must be a valid mutable name in the 'mutable_map'. Got: '{}'.".format(
412
+ mutable_name
413
+ )
414
+ )
415
+ t = mutable_map[mutable_name]
416
+ if t.dst_engine != "wml":
417
+ logger.info(f"Skipping mutable '{mutable_name}' as it is not in wml db.")
418
+ return False
419
+
420
+ msg = "Merging mutable '{}' from '{}'".format(mutable_name, other_engine.url)
421
+ with logg.scoped_info(msg, logger):
422
+ if not mutable_name in ss.list_tables(other_engine):
423
+ if logger:
424
+ logger.info("Remote mutable does not exist.")
425
+ return False
426
+
427
+ # merge structure
428
+ remote_sql_str = ss.get_table_sql_code(mutable_name, other_engine)
429
+ if mutable_name in ss.list_tables(wml_engine): # local table exists
430
+ local_sql_str = ss.get_table_sql_code(mutable_name, wml_engine)
431
+ if local_sql_str != remote_sql_str:
432
+ if logger:
433
+ logger.debug(
434
+ "Remote table has a different sql definition from the local one:"
435
+ )
436
+ logger.debug(" Local : {}".format(local_sql_str))
437
+ logger.debug(" Remote: {}".format(remote_sql_str))
438
+ logger.info("Skipping.")
439
+ return False
440
+ else: # local table does not exist
441
+ if logger:
442
+ logger.info("Creating local mutable.")
443
+ logger.debug(" SQL: {}".format(remote_sql_str))
444
+ ss.engine_execute(wml_engine, remote_sql_str)
445
+
446
+ # merge index
447
+ remote_indices = ss.list_indices(other_engine)
448
+ remote_indices = getattr(remote_indices, mutable_name, {})
449
+ local_indices = ss.list_indices(wml_engine)
450
+ local_indices = getattr(local_indices, mutable_name, {})
451
+ for index_col in remote_indices:
452
+ if index_col in local_indices:
453
+ continue
454
+ index_sql_str = remote_indices[index_col]
455
+ if logger:
456
+ logger.info("Creating index '{}'.".format(index_col))
457
+ logger.debug(" SQL: {}".format(index_sql_str))
458
+ ss.engine_execute(wml_engine, index_sql_str)
459
+
460
+ # identify duplicate records
461
+ t = mutable_map[mutable_name]
462
+ with logg.scoped_info("Comparing local <-> remote headers", logger=logger):
463
+ if t.updated_col is None:
464
+ if logger:
465
+ logger.info("Loading local ids:")
466
+ query_str = "SELECT {index_col} FROM {mutable_name}".format(
467
+ index_col=t.index_col, mutable_name=mutable_name
468
+ )
469
+ local_id_df = ss.read_sql(query_str, wml_engine)
470
+ local_id_list = local_id_df[index_col].tolist()
471
+ if logger:
472
+ logger.debug(" {} ids loaded.".format(len(local_id_df)))
473
+ logger.info("Loading remote ids:")
474
+
475
+ query_str = "SELECT {index_col} FROM {mutable_name}".format(
476
+ index_col=t.index_col, mutable_name=mutable_name
477
+ )
478
+ remote_id_df = ss.read_sql(query_str, other_engine)
479
+ remote_id_list = remote_id_df[t.index_col].tolist()
480
+ if logger:
481
+ logger.debug(" {} ids loaded.".format(len(remote_id_df)))
482
+
483
+ local_set = set(local_id_list)
484
+ remote_set = set(remote_id_list)
485
+ drop_list = []
486
+ insert_list = list(remote_set - local_set)
487
+
488
+ if logger:
489
+ logger.info(
490
+ "{} remote records to be inserted.".format(len(insert_list))
491
+ )
492
+ else:
493
+ if logger:
494
+ logger.info("Loading local headers:")
495
+ query_str = "SELECT {index_col}, {updated_col} AS local_updated FROM {mutable_name}".format(
496
+ index_col=t.index_col,
497
+ updated_col=t.updated_col,
498
+ mutable_name=mutable_name,
499
+ )
500
+ local_id_df = ss.read_sql(query_str, wml_engine, index_col=t.index_col)
501
+ if logger:
502
+ logger.debug(" {} headers loaded.".format(len(local_id_df)))
503
+ logger.info("Loading remote headers:")
504
+
505
+ query_str = "SELECT {index_col}, {updated_col} AS remote_updated FROM {mutable_name}".format(
506
+ index_col=t.index_col,
507
+ updated_col=t.updated_col,
508
+ mutable_name=mutable_name,
509
+ )
510
+ remote_id_df = ss.read_sql(
511
+ query_str, other_engine, index_col=t.index_col
512
+ )
513
+ if logger:
514
+ logger.debug(" {} headers loaded.".format(len(remote_id_df)))
515
+
516
+ # records in local but not in remote
517
+ id_df = local_id_df.join(remote_id_df, how="outer").reset_index(
518
+ drop=False
519
+ )
520
+ drop_list = []
521
+ insert_list = []
522
+ the_list = id_df[id_df["remote_updated"].isnull()][t.index_col].tolist()
523
+ if logger:
524
+ logger.debug(
525
+ "{} local records not in remote.".format(len(the_list))
526
+ )
527
+
528
+ # records in remote but not in local
529
+ id_df = id_df[id_df["remote_updated"].notnull()]
530
+ the_list = id_df[id_df["local_updated"].isnull()][t.index_col].tolist()
531
+ insert_list.extend(the_list)
532
+ if logger:
533
+ logger.debug(
534
+ "{} remote records not in local.".format(len(the_list))
535
+ )
536
+
537
+ # duplicate records
538
+ id_df = id_df[id_df["local_updated"].notnull()]
539
+ common_list = id_df[id_df["local_updated"] == id_df["remote_updated"]][
540
+ t.index_col
541
+ ].tolist()
542
+ if logger:
543
+ logger.info(
544
+ "{} common records to be kept.".format(len(common_list))
545
+ )
546
+
547
+ # new local records
548
+ id_df = id_df[id_df["local_updated"] != id_df["remote_updated"]]
549
+ the_list = id_df[id_df["local_updated"] > id_df["remote_updated"]][
550
+ t.index_col
551
+ ].tolist()
552
+ if logger:
553
+ logger.debug("{} new local records detected.".format(len(the_list)))
554
+
555
+ # new remote records
556
+ id_df = id_df[id_df["local_updated"] < id_df["remote_updated"]]
557
+ the_list = id_df[t.index_col].tolist()
558
+ drop_list.extend(the_list)
559
+ insert_list.extend(the_list)
560
+ if logger:
561
+ logger.debug(
562
+ "{} new remote records detected.".format(len(the_list))
563
+ )
564
+ logger.info(
565
+ "{} local records to be dropped.".format(len(drop_list))
566
+ )
567
+ logger.info(
568
+ "{} remote records to be inserted.".format(len(insert_list))
569
+ )
570
+
571
+ # drop records
572
+ if drop_list:
573
+ if logger:
574
+ logger.info("Dropping {} local records.".format(len(drop_list)))
575
+ query_str = ",".join((str(x) for x in drop_list))
576
+ query_str = (
577
+ "DELETE FROM {mutable_name} WHERE {index_col} IN ({values});".format(
578
+ mutable_name=mutable_name, index_col=t.index_col, values=query_str
579
+ )
580
+ )
581
+ ss.engine_execute(wml_engine, query_str)
582
+
583
+ # insert remote records
584
+ if insert_list:
585
+ msg = "Upserting {} remote records".format(len(insert_list))
586
+ with logg.scoped_info(msg, logger=logger):
587
+ while insert_list:
588
+ the_list = insert_list[:1000000]
589
+ insert_list = insert_list[1000000:]
590
+
591
+ query_str = ",".join((str(x) for x in the_list))
592
+ query_str = "SELECT * FROM {mutable_name} WHERE {index_col} IN ({values});".format(
593
+ mutable_name=mutable_name,
594
+ index_col=t.index_col,
595
+ values=query_str,
596
+ )
597
+ if logger:
598
+ logger.info(
599
+ "Loading {} remote records, {} remaining.".format(
600
+ len(the_list), len(insert_list)
601
+ )
602
+ )
603
+ df = ss.read_sql(query_str, other_engine, logger=logger)
604
+ if logger:
605
+ logger.info("Inserting {} new records.".format(len(df)))
606
+ _to_sql(
607
+ df,
608
+ mutable_name,
609
+ wml_engine,
610
+ index=False,
611
+ if_exists="append",
612
+ dtype=t.dtype,
613
+ from_mysql=t.schema == "winnow_db",
614
+ )
615
+
616
+ if logger:
617
+ logger.info("Merge completed.")
618
+
619
+ return True
620
+
621
+
622
def vacuum(level: str = "full", logger=None):
    """Cleans up duplicate records in mutables and then make the wml database compact.

    Parameters
    ----------
    level : {'full', 'minimal'}
        If 'minimal', only the mutables are vacuumed. Otherwise, in addition to 'minimal', the whole
        wml database is vacuumed.
    logger : mt.logg.IndentedLoggerAdapter, optional
        logger for debugging purposes
    """

    wml_tables = ss.list_tables(wml_engine)

    for mutable_name, t in mutable_map.items():
        # Only tables that already exist locally AND are destined for the wml
        # engine are eligible for vacuuming.
        if (mutable_name not in wml_tables) or (t.dst_engine != "wml"):
            continue

        msg = "Vacuuming mutable '{}'".format(mutable_name)
        with logg.scoped_info(msg, logger=logger):
            schema = t.schema
            table = t.table
            index_col = t.index_col
            chunk_size = t.chunk_size
            updated_col = t.updated_col
            # Load every record whose index value occurs more than once
            # (id_cnt > 1), i.e. all physical rows involved in a duplicate.
            query_str = """WITH t1 AS (SELECT {index_col}, COUNT(*) AS id_cnt FROM {mutable_name} GROUP BY {index_col})
                SELECT {mutable_name}.* FROM {mutable_name}
                LEFT JOIN t1 ON {mutable_name}.{index_col}=t1.{index_col}
                WHERE t1.id_cnt > 1
                ;""".format(
                mutable_name=mutable_name,
                index_col=index_col,
            )
            df = ss.read_sql(query_str, wml_engine, logger=logger)
            if len(df) == 0:
                if logger:
                    logger.info("The table is clean.")
                continue

            id_list = df[index_col].drop_duplicates().tolist()
            if logger:
                logger.info(
                    "Detected {} duplicate ids spanning over {} records.".format(
                        len(id_list), len(df)
                    )
                )

            # Delete ALL rows carrying a duplicated id; a single clean copy of
            # each is reinserted below from the dataframe already in memory.
            query_str = ",".join((str(x) for x in id_list))
            query_str = (
                "DELETE FROM {mutable_name} WHERE {index_col} IN ({values});".format(
                    mutable_name=mutable_name,
                    index_col=index_col,
                    values=query_str,
                )
            )
            ss.engine_execute(wml_engine, query_str)
            if logger:
                logger.info("Deleted {} duplicate records.".format(len(df)))

            # Pick one survivor per id: the first occurrence when there is no
            # update timestamp, otherwise the most recently updated row.
            if updated_col is None:
                df = df.groupby(index_col).head(1)
            else:
                df = (
                    df.sort_values([index_col, updated_col], ascending=[True, False])
                    .groupby(index_col)
                    .head(1)
                )
            df = df.set_index(index_col, drop=True)
            _to_sql(
                df,
                mutable_name,
                wml_engine,
                if_exists="append",
                dtype=t.dtype,
                # 'winnow_db' tables originate from MySQL and need type fixups
                # on the way in — presumably handled inside _to_sql; confirm.
                from_mysql=t.schema == "winnow_db",
            )
            if logger:
                logger.info("Reinserted {} clean records.".format(len(df)))

    # Any level other than 'full' skips the database-wide vacuum.
    if level == "full":
        if logger:
            logger.info("Vacuuming the wml database.")
        ss.vacuum(wml_engine)

    if logger:
        logger.info("Done.")
708
+
709
+ '''
mb/sql/utils.py ADDED
@@ -0,0 +1,148 @@
1
## file for extra sql functions

import os

import pandas as pd

from .sql import read_sql, engine_execute
from mb.utils.logging import logg

# NOTE: 'list_tables' was previously missing from __all__, so it was silently
# dropped by star-imports even though it is part of the module's public API.
__all__ = ['list_schemas', 'list_tables', 'rename_table', 'drop_table',
           'drop_schema', 'create_schema', 'create_index', 'clone_db']
9
+
10
def list_schemas(engine, logger=None) -> pd.DataFrame:
    """
    Return the non-system schemas present in the database.

    Args:
        engine (sqlalchemy.engine.base.Engine): Engine object.
        logger (logging.Logger): Logger object. Default: mb_utils.src.logging.logger
    Returns:
        df (pandas.core.frame.DataFrame): one row per schema name, sorted alphabetically.
    """
    query = (
        "SELECT schema_name FROM information_schema.schemata "
        "WHERE schema_name NOT IN ('information_schema', 'mysql', 'performance_schema') "
        "ORDER BY schema_name;"
    )
    return read_sql(query, engine, logger=logger)
22
+
23
def list_tables(engine, schema=None, logger=None) -> pd.DataFrame:
    """
    Returns list of tables in database.

    Args:
        engine (sqlalchemy.engine.base.Engine): Engine object.
        schema (str): Name of the schema. Default: None (tables from all non-system schemas)
        logger (logging.Logger): Logger object. Default: mb_utils.src.logging.logger
    Returns:
        df (pandas.core.frame.DataFrame): DataFrame object.
    """
    if schema is not None:
        q1 = f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema}' ORDER BY table_name;"
    else:
        # Fix: a None schema used to be interpolated as the literal string
        # 'None', so the documented default matched nothing. Fall back to
        # listing tables from every non-system schema instead.
        q1 = (
            "SELECT table_name FROM information_schema.tables "
            "WHERE table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'pg_catalog') "
            "ORDER BY table_name;"
        )
    return read_sql(q1, engine, logger=logger)
37
+
38
+
39
def rename_table(new_table_name, old_table_name, engine, schema_name=None, logger=None):
    """
    Rename a table in the database, optionally qualified by a schema.

    Args:
        new_table_name (str): New name of the table.
        old_table_name (str): Old name of the table.
        engine (sqlalchemy.engine.base.Engine): Engine object.
        schema_name (str): Name of the schema. Default: None
        logger (logging.Logger): Logger object. Default: mb_utils.src.logging.logger
    Returns:
        None
    """
    qualified = f"{schema_name}.{old_table_name}" if schema_name else old_table_name
    engine_execute(engine, f"ALTER TABLE {qualified} RENAME TO {new_table_name};")
    logg.info(f'Table {old_table_name} renamed to {new_table_name}.', logger=logger)
58
+
59
def drop_table(table_name, engine, schema_name=None, logger=None):
    """
    Drop a table from the database, optionally qualified by a schema.

    Args:
        table_name (str): Name of the table.
        engine (sqlalchemy.engine.base.Engine): Engine object.
        schema_name (str): Name of the schema. Default: None
        logger (logging.Logger): Logger object. Default: mb_utils.src.logging.logger
    Returns:
        None
    """
    qualified = f"{schema_name}.{table_name}" if schema_name else table_name
    engine_execute(engine, f"DROP TABLE {qualified};")
    logg.info(f'Table {table_name} dropped.', logger=logger)
77
+
78
def drop_schema(schema_name, engine, logger=None):
    """
    Drop a schema from the database.

    Args:
        schema_name (str): Name of the schema.
        engine (sqlalchemy.engine.base.Engine): Engine object.
        logger (logging.Logger): Logger object. Default: mb_utils.src.logging.logger
    Returns:
        None
    """
    engine_execute(engine, f"DROP SCHEMA {schema_name};")
    logg.info(f'Schema {schema_name} dropped.', logger=logger)
92
+
93
def create_schema(schema_name, engine, logger=None):
    """
    Create a new schema in the database.

    Args:
        schema_name (str): Name of the schema.
        engine (sqlalchemy.engine.base.Engine): Engine object.
        logger (logging.Logger): Logger object. Default: mb_utils.src.logging.logger
    Returns:
        None
    """
    engine_execute(engine, f"CREATE SCHEMA {schema_name};")
    logg.info(f'Schema {schema_name} created.', logger=logger)
107
+
108
def create_index(table, index_col, engine, logger=None):
    """
    Create an index on a column of a table in the database.

    The index is named after the column it covers.

    Args:
        table (str): Name of the table.
        index_col (str): Name of the column to index (also used as the index name).
        engine (sqlalchemy.engine.base.Engine): Engine object.
        logger (logging.Logger): Logger object. Default: mb_utils.src.logging.logger
    Returns:
        None
    """
    # Fix: the previous statement "CREATE INDEX {index_col} ON {table};" was
    # invalid SQL — CREATE INDEX requires a column list after the table name.
    q1 = f"CREATE INDEX {index_col} ON {table} ({index_col});"
    engine_execute(engine, q1)
    logg.info(f'Index {index_col} created for table {table}.', logger=logger)
123
+
124
def clone_db(ori_db_location, copy_db_location, logger=None):
    """
    Clone a database by piping pg_dump into psql.

    Args:
        ori_db_location (str): Location of the original database.
        copy_db_location (str): Location for the new database. If it already
            exists, a 'copy_db' subdirectory is used instead.
        logger (logging.Logger): Logger object. Default: mb_utils.src.logging.logger
    Raises:
        FileNotFoundError: if the original location does not exist.
        RuntimeError: if the pg_dump/psql pipeline exits with a non-zero status.
    Returns:
        None
    """
    if not os.path.exists(ori_db_location):
        raise FileNotFoundError("The original location does not exist.")

    if os.path.exists(copy_db_location):
        copy_db_location = os.path.join(copy_db_location, 'copy_db')

    if not os.path.exists(copy_db_location):
        os.makedirs(copy_db_location)

    # NOTE(review): the locations are interpolated into a shell command; only
    # call this with trusted paths, or migrate to subprocess with an argument
    # list to avoid shell injection.
    cmd = f"pg_dump -U postgres -h localhost -p 5432 {ori_db_location} | psql -U postgres -h localhost -p 5432 {copy_db_location}"
    # Fix: the exit status of os.system was previously ignored, so a failed
    # clone was logged as a success.
    status = os.system(cmd)
    if status != 0:
        raise RuntimeError(f"Database clone failed with exit status {status}.")

    logg.info(f'Database cloned from {ori_db_location} to {copy_db_location}.', logger=logger)
148
+
mb/sql/version.py ADDED
@@ -0,0 +1,5 @@
1
# Semantic version components of the mb_sql package.
MAJOR_VERSION = 2
MINOR_VERSION = 0
PATCH_VERSION = 1

# Full dotted version string, e.g. '2.0.1'.
version = f"{MAJOR_VERSION}.{MINOR_VERSION}.{PATCH_VERSION}"

__all__ = ['MAJOR_VERSION', 'MINOR_VERSION', 'PATCH_VERSION', 'version']
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.4
2
+ Name: mb_sql
3
+ Version: 2.0.1
4
+ Summary: mb SQL functions API
5
+ Author: ['Malav Bateriwala']
6
+ Requires-Python: >=3.8
7
+ Requires-Dist: mb_utils
8
+ Dynamic: author
9
+ Dynamic: requires-dist
10
+ Dynamic: requires-python
11
+ Dynamic: summary
@@ -0,0 +1,13 @@
1
+ mb/sql/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ mb/sql/conn.py,sha256=HTyAUDkBlp_yvDRNRS9OoTRJPvrcS2SEvnx46AXK3OA,3210
3
+ mb/sql/slack.py,sha256=3dWbKvhnp1sIEWGk3m43WnfXbtW56Ohk41_5puZnH-A,549
4
+ mb/sql/snapshot.py,sha256=3C7vqRE3vUwyEEQQIsSIfz2JMjli2hV5Tm4fONbgEbo,4775
5
+ mb/sql/sql.py,sha256=cj00ta1j53o7VRf3LeXj9uGE8p1IcUF7CUFbWrVM0AA,3655
6
+ mb/sql/tables.py,sha256=xECYrXJKvdGvCPZo9vT1o7-HO0LoFjaD8602BQo4oO4,1669
7
+ mb/sql/update.py,sha256=Ed9MA3y9ZJ3arwSq9cbIPBe1hI5PCqmPyY6rD5huq64,26741
8
+ mb/sql/utils.py,sha256=Bh1eadnfnckeDJg25rhnrDXJ3YVN8gjp2AedgyM8Iww,5248
9
+ mb/sql/version.py,sha256=JJiVhzInvyc4XUc31yQR_q4UNm6rnom-oatXO_VvCfo,206
10
+ mb_sql-2.0.1.dist-info/METADATA,sha256=tOr2lC_YFhpIcrBirbHz_tsq7DsmlTvsaX8AkmRdLzM,237
11
+ mb_sql-2.0.1.dist-info/WHEEL,sha256=YCfwYGOYMi5Jhw2fU4yNgwErybb2IX5PEwBKV4ZbdBo,91
12
+ mb_sql-2.0.1.dist-info/top_level.txt,sha256=2T5lqIVZs7HUr0KqUMPEaPF59QXcr0ErDy6K-J88ycM,3
13
+ mb_sql-2.0.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ mb