ry-pg-utils 1.0.6__tar.gz → 1.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ry_pg_utils-1.0.6/src/ry_pg_utils.egg-info → ry_pg_utils-1.0.8}/PKG-INFO +1 -1
- ry_pg_utils-1.0.8/VERSION +1 -0
- ry_pg_utils-1.0.8/src/ry_pg_utils/tools/__init__.py +0 -0
- ry_pg_utils-1.0.8/src/ry_pg_utils/tools/db_query.py +511 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8/src/ry_pg_utils.egg-info}/PKG-INFO +1 -1
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils.egg-info/SOURCES.txt +3 -1
- ry_pg_utils-1.0.6/VERSION +0 -1
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/LICENSE +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/MANIFEST.in +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/README.md +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/packages/base_requirements.in +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/pyproject.toml +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/setup.cfg +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/setup.py +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/__init__.py +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/config.py +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/connect.py +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/dynamic_table.py +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/ipc/__init__.py +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/ipc/channels.py +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/notify_trigger.py +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/parse_args.py +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/pb_types/__init__.py +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/pb_types/database_pb2.py +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/pb_types/database_pb2.pyi +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/pb_types/py.typed +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/postgres_info.py +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/py.typed +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils/updater.py +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils.egg-info/dependency_links.txt +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils.egg-info/requires.txt +0 -0
- {ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils.egg-info/top_level.txt +0 -0
ry_pg_utils-1.0.8/VERSION  ADDED
@@ -0,0 +1 @@
+1.0.8

ry_pg_utils-1.0.8/src/ry_pg_utils/tools/__init__.py  ADDED
File without changes (new, empty file)
ry_pg_utils-1.0.8/src/ry_pg_utils/tools/db_query.py  ADDED
@@ -0,0 +1,511 @@
import abc
import concurrent.futures
import os
import subprocess
import time
import typing as T

import pandas as pd
import paramiko
import psycopg2
from pyspark.sql import SparkSession
from ryutils import log, modern_ssh_tunnel

from ry_pg_utils import config


class DbQuery(abc.ABC):
    DB_STALE_MINS = 15
    LOCAL_HOST = "127.0.0.1"

    # pylint: disable=too-many-arguments
    # pylint: disable=too-many-positional-arguments
    def __init__(
        self,
        postgres_host: str | None = None,
        postgres_port: int | None = None,
        postgres_database: str | None = None,
        postgres_user: str | None = None,
        postgres_password: str | None = None,
        postrgres_url: str | None = None,
        ssh_host: str | None = None,
        ssh_port: int | None = None,
        ssh_user: str | None = None,
        ssh_pkey: str | None = None,
        is_local: bool = False,
        verbose: bool = False,
    ):
        self.is_local = is_local

        self.postgres_host = (
            postgres_host if postgres_host is not None else config.pg_config.postgres_host
        )
        self.postgres_port = (
            postgres_port if postgres_port is not None else config.pg_config.postgres_port
        )
        self.postgres_database = (
            postgres_database if postgres_database is not None else config.pg_config.postgres_db
        )
        self.postgres_user = (
            postgres_user if postgres_user is not None else config.pg_config.postgres_user
        )
        self.postgres_password = (
            postgres_password
            if postgres_password is not None
            else config.pg_config.postgres_password
        )
        self.postgres_uri = (
            f"postgresql://{self.postgres_user}:{self.postgres_password}@"
            f"{self.postgres_host}:{self.postgres_port}/{self.postgres_database}"
        )
        self.postrgres_url = postrgres_url if postrgres_url is not None else self.postgres_uri

        self.ssh_host = ssh_host if ssh_host is not None else config.pg_config.ssh_host
        self.ssh_port = ssh_port if ssh_port is not None else config.pg_config.ssh_port
        self.ssh_user = ssh_user if ssh_user is not None else config.pg_config.ssh_user
        self.ssh_pkey = ssh_pkey if ssh_pkey is not None else config.pg_config.ssh_key_path

        self.db_name = f"temp_{self.postgres_database}" if is_local else self.postgres_database

        if verbose:
            log.print_normal(ssh_host, ssh_port, ssh_user, ssh_pkey)
            log.print_normal(
                postgres_host,
                postgres_port,
                postgres_database,
                postgres_user,
                postgres_password,
                postrgres_url,
                ssh_host,
                ssh_port,
                ssh_user,
                ssh_pkey,
            )

        self.ssh_tunnel: modern_ssh_tunnel.SSHTunnelForwarder | None = None
        self.conn: psycopg2.extensions.connection | SparkSession | None = None

    @staticmethod
    def db(table_name: str) -> str:
        return f'"public"."{table_name}"'

    def _maybe_copy_database_locally(self, local_db_path: str) -> None:
        # Check if the local database copy was modified within the staleness window
        if os.path.exists(local_db_path):
            last_modified_time = os.path.getmtime(local_db_path)
            current_time = time.time()
            if current_time - last_modified_time < 60.0 * self.DB_STALE_MINS:
                log.print_normal(
                    f"Local database copy was modified in the last "
                    f"{self.DB_STALE_MINS} minutes. Skipping copy."
                )
                return

        # Copy the database locally
        log.print_normal("Copying database locally...")

        # Validate required parameters
        assert self.ssh_host is not None, "SSH host is required"
        assert self.ssh_port is not None, "SSH port is required"
        assert self.ssh_user is not None, "SSH user is required"
        assert self.ssh_pkey is not None, "SSH private key path is required"

        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        client.connect(
            self.ssh_host,
            port=self.ssh_port,
            username=self.ssh_user,
            key_filename=self.ssh_pkey,
        )
        sftp = client.open_sftp()

        remote_temp_file_path = "/tmp/backup_file"
        command = (
            f"PGPASSWORD={self.postgres_password} pg_dump -U {self.postgres_user} "
            f"-h {self.postgres_host} -p {self.postgres_port} -F c -b -v -f "
            f"{remote_temp_file_path} {self.postgres_database}"
        )
        log.print_normal(f"Running command: PGPASSWORD=**** {' '.join(command.split()[1:])}")

        # Run pg_dump on the remote server
        _, stdout, stderr = client.exec_command(command)
        log.print_normal(stdout.read().decode())
        log.print_normal(stderr.read().decode())

        # Fetch the dump over SFTP
        sftp.get(remote_temp_file_path, local_db_path)
        sftp.close()
        client.close()

    def _import_local_database(self, local_db_path: str, temp_db_name: str) -> None:
        assert self.postgres_password is not None, "Postgres password is required"
        os.environ["PGPASSWORD"] = self.postgres_password

        # Check if the temporary database already exists
        check_db_exists_command = (
            f"psql -U {self.postgres_user} -h {self.postgres_host} "
            f"-p {self.postgres_port} -lqt | cut -d \\| -f 1 | grep -w {temp_db_name}"
        )
        log.print_normal(
            f"Checking if database {temp_db_name} already exists...\n{check_db_exists_command}"
        )
        result = subprocess.run(
            check_db_exists_command, shell=True, check=True, stdout=subprocess.PIPE
        )

        drop_db_command = (
            f"psql -U {self.postgres_user} -h {self.postgres_host} "
            f"-p {self.postgres_port} -c 'DROP DATABASE {temp_db_name}'"
        )
        if result.stdout or result.stderr:
            log.print_normal(f"Database {temp_db_name} already exists. Dropping it first...")
            subprocess.run(drop_db_command, shell=True, check=True)

        log.print_normal("Creating temporary database...")
        create_db_command = (
            f"psql -U {self.postgres_user} -h {self.postgres_host} "
            f"-p {self.postgres_port} -c 'CREATE DATABASE {temp_db_name}'"
        )
        try:
            subprocess.run(drop_db_command, shell=True, check=True)
        except Exception:  # pylint: disable=broad-except
            pass

        subprocess.run(create_db_command, shell=True, check=True)

        log.print_normal("Restoring database into the temporary database...")

        # Restore the database from the binary dump file using pg_restore
        restore_db_command = (
            f"pg_restore -U {self.postgres_user} -h {self.postgres_host} "
            f"-p {self.postgres_port} -d {temp_db_name} -v {local_db_path}"
        )
        subprocess.run(restore_db_command, shell=True, check=True)

        log.print_normal("Database import complete!")
        log.print_normal(f"Database imported into temporary database {temp_db_name}")

    @abc.abstractmethod
    def connect(self, use_ssh_tunnel: bool = False) -> None:
        pass

    @abc.abstractmethod
    def load_tables(self, tables: T.List[str]) -> T.Dict[str, pd.DataFrame]:
        pass

    def _establish_ssh_tunnel(self) -> None:
        log.print_normal(
            "Establishing SSH tunnel: ",
            self.ssh_host,
            self.ssh_port,
            self.ssh_user,
            self.ssh_pkey,
        )

        # Validate required parameters
        assert self.ssh_host is not None, "SSH host is required"
        assert self.ssh_port is not None, "SSH port is required"
        assert self.ssh_user is not None, "SSH user is required"
        assert self.ssh_pkey is not None, "SSH private key path is required"
        assert self.postgres_port is not None, "Postgres port is required"

        self.ssh_tunnel = modern_ssh_tunnel.SSHTunnelForwarder(
            (self.ssh_host, self.ssh_port),
            ssh_username=self.ssh_user,
            ssh_pkey=self.ssh_pkey,
            remote_bind_address=(self.LOCAL_HOST, self.postgres_port),
        )

        self.ssh_tunnel.start()
        log.print_ok_arrow(f"SSH tunnel active: {self.ssh_tunnel.is_active}")

    def query(self, query: str, verbose: bool = False) -> pd.DataFrame:
        start_time = time.time()

        if verbose:
            log.print_normal("=" * 80)
            log.print_normal(query)
            log.print_normal("=" * 80)

        try:
            df = pd.read_sql_query(query, self.conn)  # type: ignore
            time_delta = time.time() - start_time
            if verbose:
                log.print_ok_arrow(f"Time taken to query database: {time_delta:.2f} seconds")
            return T.cast(pd.DataFrame, df)
        except Exception as e:  # pylint: disable=broad-except
            log.print_normal(f"Error executing query: {e}")

        return pd.DataFrame()

    @abc.abstractmethod
    def clear(self, table: str) -> None:
        pass

    def close(self) -> None:
        if isinstance(self.conn, SparkSession):
            self.conn.stop()
        elif isinstance(self.conn, psycopg2.extensions.connection):
            self.conn.close()
        if self.ssh_tunnel is not None and self.ssh_tunnel.is_active:
            self.ssh_tunnel.stop()
            log.print_ok_arrow("SSH tunnel closed successfully.")

        self.ssh_tunnel = None
        self.conn = None

        log.print_ok_arrow("Connection to the database closed successfully.")

    def copy_db_local(self, local_db_path: str) -> None:
        assert self.db_name is not None, "Database name is required"
        self._maybe_copy_database_locally(local_db_path=local_db_path)
        self._import_local_database(local_db_path=local_db_path, temp_db_name=self.db_name)

    def run_command(self, command: str) -> None:
        """Run a command on the remote server via direct SSH connection."""
        # Validate required parameters
        assert self.ssh_host is not None, "SSH host is required"
        assert self.ssh_port is not None, "SSH port is required"
        assert self.ssh_user is not None, "SSH user is required"
        assert self.ssh_pkey is not None, "SSH private key path is required"

        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        try:
            client.connect(
                hostname=self.ssh_host,
                port=self.ssh_port,
                username=self.ssh_user,
                key_filename=self.ssh_pkey,
            )

            log.print_normal("Running command on remote server:")
            log.print_normal("=" * 80)
            log.print_normal(command)
            log.print_normal("=" * 80)

            _, stdout, stderr = client.exec_command(command)
            log.print_normal(stdout.read().decode())
            log.print_normal(stderr.read().decode())

        except Exception as e:  # pylint: disable=broad-except
            log.print_fail(f"Failed to run command: {e}")
        finally:
            client.close()
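The base class above wires together a remote pg_dump over SSH, an SFTP fetch, and a local pg_restore. As a rough illustration of how copy_db_local() is meant to be driven (the hostname, key path, and database name below are hypothetical, and DbQueryPsycopg2 is the concrete subclass defined next), a minimal sketch:

from ry_pg_utils.tools.db_query import DbQueryPsycopg2

db = DbQueryPsycopg2(
    postgres_database="mydb",            # hypothetical database name
    ssh_host="db.example.com",           # hypothetical remote host
    ssh_port=22,
    ssh_user="deploy",                   # hypothetical user
    ssh_pkey="/home/me/.ssh/id_rsa",     # hypothetical key path
    is_local=True,                       # db_name becomes "temp_mydb"
)
# pg_dump runs on the remote host over SSH, the custom-format dump is fetched
# via SFTP (skipped when the local copy is younger than DB_STALE_MINS), and
# pg_restore loads it into "temp_mydb" on the configured Postgres host.
db.copy_db_local("/tmp/mydb.dump")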
class DbQueryPsycopg2(DbQuery):
    def __init__(self, *args: T.Any, **kwargs: T.Any) -> None:
        super().__init__(*args, **kwargs)
        # Store the actual connection URI for pandas queries
        self._connection_uri: str | None = None

    def connect(self, use_ssh_tunnel: bool = False) -> None:
        postgres_host: str
        bind_port: int

        if use_ssh_tunnel:
            self._establish_ssh_tunnel()
            assert self.ssh_tunnel is not None, "SSH tunnel is not active."
            postgres_host = self.LOCAL_HOST
            bind_port = self.ssh_tunnel.local_bind_port
        else:
            assert self.postgres_host is not None, "Postgres host is required"
            assert self.postgres_port is not None, "Postgres port is required"
            postgres_host = self.postgres_host
            bind_port = self.postgres_port

        # Store the actual connection URI for pandas
        self._connection_uri = (
            f"postgresql://{self.postgres_user}:{self.postgres_password}@"
            f"{postgres_host}:{bind_port}/{self.postgres_database}"
        )

        log.print_normal(f"PostgreSQL bind port: {self.postgres_host} -> {self.postgres_port}")
        # Connect to the PostgreSQL database
        self.conn = psycopg2.connect(
            host=postgres_host,
            port=bind_port,
            dbname=self.postgres_database,
            user=self.postgres_user,
            password=self.postgres_password,
        )
        log.print_ok_arrow("Connection to the database established successfully.")

    def query(self, query: str, verbose: bool = False) -> pd.DataFrame:
        """Override query to use connection URI for pandas."""
        start_time = time.time()

        if verbose:
            log.print_normal("=" * 80)
            log.print_normal(query)
            log.print_normal("=" * 80)

        try:
            # Use the actual connection URI to avoid pandas warning
            connection_uri = self._connection_uri if self._connection_uri else self.postgres_uri
            df = pd.read_sql_query(query, connection_uri)
            time_delta = time.time() - start_time
            if verbose:
                log.print_ok_arrow(f"Time taken to query database: {time_delta:.2f} seconds")
            return T.cast(pd.DataFrame, df)
        except Exception as e:  # pylint: disable=broad-except
            log.print_normal(f"Error executing query: {e}")

        return pd.DataFrame()

    def clear(self, table: str) -> None:
        if self.conn is None:
            log.print_fail("Database connection is not active. Cannot clear table.")
            return

        conn = T.cast(psycopg2.extensions.connection, self.conn)
        clear_query = f"DELETE FROM {table};"
        try:
            with conn.cursor() as cursor:
                cursor.execute(clear_query)
                conn.commit()
                print(f"{table} table cleared successfully.")
        except Exception as e:  # pylint: disable=broad-except
            print(f"Error clearing {table} table: {e}")

    def load_tables(self, tables: T.List[str]) -> T.Dict[str, pd.DataFrame]:
        dfs = {}
        for table in tables:
            table_name, df = self._load_table(table)
            dfs[table_name] = df if df is not None else pd.DataFrame()

        return dfs

    def _load_table(self, table: str) -> T.Tuple[str, pd.DataFrame | None]:
        try:
            query = f"SELECT * FROM {table}"
            return table, self.query(query)
        except Exception as e:  # pylint: disable=broad-except
            print(f"Error loading data for {table}: {e}")
            return table, None
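A minimal sketch of the psycopg2 path, assuming config.pg_config supplies working defaults (the table name is hypothetical):

from ry_pg_utils.tools.db_query import DbQueryPsycopg2

db = DbQueryPsycopg2()            # all parameters fall back to config.pg_config
db.connect(use_ssh_tunnel=True)   # opens the tunnel, then psycopg2.connect() to the bound port
df = db.query("SELECT 1 AS ok", verbose=True)   # pandas DataFrame, read via the connection URI
dfs = db.load_tables(["my_table"])              # {"my_table": DataFrame}
db.close()                        # closes the connection and stops the tunnel

Note the URI-based query() override: per the inline comment, pandas warns when handed a raw DBAPI connection instead of a connection string, which appears to be why the class records _connection_uri at connect time.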
# pylint: disable=too-many-arguments
# pylint: disable=too-many-positional-arguments
class DbQuerySpark(DbQuery):
    JDBC_DRIVER_PATH = "/usr/share/java/postgresql.jar"
    PARALLEL_LOAD = False

    def __init__(
        self,
        postgres_host: str | None = None,
        postgres_port: int | None = None,
        postgres_database: str | None = None,
        postgres_user: str | None = None,
        postgres_password: str | None = None,
        postrgres_url: str | None = None,
        ssh_host: str | None = None,
        ssh_port: int | None = None,
        ssh_user: str | None = None,
        ssh_pkey: str | None = None,
        is_local: bool = False,
        verbose: bool = False,
    ) -> None:
        super().__init__(
            postgres_host,
            postgres_port,
            postgres_database,
            postgres_user,
            postgres_password,
            postrgres_url,
            ssh_host,
            ssh_port,
            ssh_user,
            ssh_pkey,
            is_local,
            verbose,
        )
        self.jdbc_url = (
            f"jdbc:postgresql://{self.postgres_host}:{self.postgres_port}/{self.db_name}"
        )

        # Validate required parameters for Spark connection
        assert self.postgres_user is not None, "Postgres user is required"
        assert self.postgres_password is not None, "Postgres password is required"

        self.connection_properties: T.Dict[str, str] = {
            "user": self.postgres_user,
            "password": self.postgres_password,
            "driver": "org.postgresql.Driver",
        }

        self.conn = (
            SparkSession.builder.appName("PostgreSQLConnection")
            .config("spark.jars", self.JDBC_DRIVER_PATH)
            .getOrCreate()
        )

    def clear(self, table: str) -> None:
        if self.conn is None:
            log.print_fail("Database connection is not active. Cannot clear table.")
            return
        try:
            self.conn.sql(f"DROP TABLE IF EXISTS {table}")  # type: ignore
            log.print_ok_arrow(f"Table {table} cleared successfully.")
        except Exception as e:  # pylint: disable=broad-except
            log.print_fail(f"Error clearing table {table}: {e}")

    def connect(self, use_ssh_tunnel: bool = False) -> None:
        postgres_host = self.postgres_host
        bind_port = self.postgres_port

        if use_ssh_tunnel:
            self._establish_ssh_tunnel()

            assert self.ssh_tunnel is not None, "SSH tunnel is not active."

            postgres_host = self.LOCAL_HOST
            bind_port = self.ssh_tunnel.local_bind_port

        if self.is_local:
            self.jdbc_url = (
                f"jdbc:postgresql://{self.LOCAL_HOST}:{self.postgres_port}/{self.postgres_database}"
            )
        else:
            self.jdbc_url = f"jdbc:postgresql://{postgres_host}:{bind_port}/{self.db_name}"

        log.print_normal(f"PostgreSQL bind port: {postgres_host} -> {bind_port}")

    def load_tables(self, tables: T.List[str]) -> T.Dict[str, pd.DataFrame]:
        dfs = {}
        if self.PARALLEL_LOAD:
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future_to_table = {
                    executor.submit(self._load_table, table): table for table in tables
                }
                for future in concurrent.futures.as_completed(future_to_table):
                    table = future_to_table[future]
                    try:
                        table_name, df = future.result()
                        if df is not None:
                            dfs[table_name] = df
                    except Exception as e:  # pylint: disable=broad-except
                        print(f"Error processing table {table}: {e}")
        else:
            for table in tables:
                table_name, df = self._load_table(table)
                dfs[table_name] = df if df is not None else pd.DataFrame()

        return dfs

    def _load_table(self, table: str) -> T.Tuple[str, pd.DataFrame | None]:
        conn = T.cast(SparkSession, self.conn)
        try:
            spark_df = conn.read.jdbc(
                url=self.jdbc_url, table=self.db(table), properties=self.connection_properties
            )
            return table, spark_df.toPandas()
        except Exception as e:  # pylint: disable=broad-except
            print(f"Error loading data for {table}: {e}")
            return table, None
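A minimal sketch of the Spark path, assuming the PostgreSQL JDBC driver jar exists at JDBC_DRIVER_PATH and config.pg_config is populated (the table name is hypothetical):

from ry_pg_utils.tools.db_query import DbQuerySpark

db = DbQuerySpark()                 # the SparkSession is created in __init__
db.connect(use_ssh_tunnel=False)    # recomputes self.jdbc_url; no separate connection object
dfs = db.load_tables(["my_table"])  # reads "public"."my_table" over JDBC into pandas
db.close()                          # stops the SparkSession and any active tunnel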
{ry_pg_utils-1.0.6 → ry_pg_utils-1.0.8}/src/ry_pg_utils.egg-info/SOURCES.txt
@@ -24,4 +24,6 @@ src/ry_pg_utils/ipc/channels.py
 src/ry_pg_utils/pb_types/__init__.py
 src/ry_pg_utils/pb_types/database_pb2.py
 src/ry_pg_utils/pb_types/database_pb2.pyi
-src/ry_pg_utils/pb_types/py.typed
+src/ry_pg_utils/pb_types/py.typed
+src/ry_pg_utils/tools/__init__.py
+src/ry_pg_utils/tools/db_query.py
ry_pg_utils-1.0.6/VERSION  DELETED
@@ -1 +0,0 @@
-1.0.6

All remaining files (listed above with +0 -0) are unchanged between the two versions.