qoptimizer-utils 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. qoptimizer_utils-0.1.0/PKG-INFO +21 -0
  2. qoptimizer_utils-0.1.0/db_utils/__init__.py +1 -0
  3. qoptimizer_utils-0.1.0/db_utils/clickhouse_access.py +178 -0
  4. qoptimizer_utils-0.1.0/db_utils/config.py +97 -0
  5. qoptimizer_utils-0.1.0/db_utils/db_access.py +9 -0
  6. qoptimizer_utils-0.1.0/db_utils/logger.py +89 -0
  7. qoptimizer_utils-0.1.0/db_utils/neo4j_access.py +146 -0
  8. qoptimizer_utils-0.1.0/db_utils/postgres/__init__.py +1 -0
  9. qoptimizer_utils-0.1.0/db_utils/postgres/connect.py +138 -0
  10. qoptimizer_utils-0.1.0/db_utils/postgres/helpers.py +303 -0
  11. qoptimizer_utils-0.1.0/db_utils/postgres/queries.py +16 -0
  12. qoptimizer_utils-0.1.0/db_utils/postgres/types.py +20 -0
  13. qoptimizer_utils-0.1.0/db_utils/postgres/writer.py +76 -0
  14. qoptimizer_utils-0.1.0/db_utils/snowflake_access.py +290 -0
  15. qoptimizer_utils-0.1.0/db_utils/supabase_access.py +151 -0
  16. qoptimizer_utils-0.1.0/pyproject.toml +34 -0
  17. qoptimizer_utils-0.1.0/qoptimizer_utils.egg-info/PKG-INFO +21 -0
  18. qoptimizer_utils-0.1.0/qoptimizer_utils.egg-info/SOURCES.txt +47 -0
  19. qoptimizer_utils-0.1.0/qoptimizer_utils.egg-info/dependency_links.txt +1 -0
  20. qoptimizer_utils-0.1.0/qoptimizer_utils.egg-info/requires.txt +10 -0
  21. qoptimizer_utils-0.1.0/qoptimizer_utils.egg-info/top_level.txt +2 -0
  22. qoptimizer_utils-0.1.0/setup.cfg +4 -0
  23. qoptimizer_utils-0.1.0/shared/__init__.py +0 -0
  24. qoptimizer_utils-0.1.0/shared/config.py +86 -0
  25. qoptimizer_utils-0.1.0/shared/connectors/__init__.py +0 -0
  26. qoptimizer_utils-0.1.0/shared/connectors/base.py +294 -0
  27. qoptimizer_utils-0.1.0/shared/connectors/bigquery.py +188 -0
  28. qoptimizer_utils-0.1.0/shared/connectors/clickhouse.py +370 -0
  29. qoptimizer_utils-0.1.0/shared/connectors/postgresql.py +299 -0
  30. qoptimizer_utils-0.1.0/shared/connectors/registry.py +22 -0
  31. qoptimizer_utils-0.1.0/shared/connectors/snowflake.py +267 -0
  32. qoptimizer_utils-0.1.0/shared/constants/__init__.py +15 -0
  33. qoptimizer_utils-0.1.0/shared/constants/issues.py +30 -0
  34. qoptimizer_utils-0.1.0/shared/constants/thresholds.py +22 -0
  35. qoptimizer_utils-0.1.0/shared/constants/vendors.py +3 -0
  36. qoptimizer_utils-0.1.0/shared/queries/__init__.py +0 -0
  37. qoptimizer_utils-0.1.0/shared/queries/bigquery.py +33 -0
  38. qoptimizer_utils-0.1.0/shared/queries/clickhouse.py +109 -0
  39. qoptimizer_utils-0.1.0/shared/queries/postgresql.py +82 -0
  40. qoptimizer_utils-0.1.0/shared/queries/snowflake.py +64 -0
  41. qoptimizer_utils-0.1.0/shared/security/__init__.py +0 -0
  42. qoptimizer_utils-0.1.0/shared/security/encryption.py +13 -0
  43. qoptimizer_utils-0.1.0/shared/services/__init__.py +0 -0
  44. qoptimizer_utils-0.1.0/shared/services/analysis.py +102 -0
  45. qoptimizer_utils-0.1.0/shared/utils/__init__.py +2 -0
  46. qoptimizer_utils-0.1.0/shared/utils/db.py +295 -0
  47. qoptimizer_utils-0.1.0/shared/utils/hashing.py +7 -0
  48. qoptimizer_utils-0.1.0/shared/utils/pg_snapshot_store.py +178 -0
  49. qoptimizer_utils-0.1.0/shared/utils/query_normalizer.py +16 -0
@@ -0,0 +1,21 @@
1
+ Metadata-Version: 2.4
2
+ Name: qoptimizer-utils
3
+ Version: 0.1.0
4
+ Summary: Shared utilities for Query Optimizer (shared/ + db_utils/)
5
+ License: Proprietary
6
+ Project-URL: Repository, https://github.com/dhkim77000/qoptimize_utils
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Programming Language :: Python :: 3.10
9
+ Classifier: Programming Language :: Python :: 3.11
10
+ Classifier: Programming Language :: Python :: 3.12
11
+ Requires-Python: >=3.10
12
+ Requires-Dist: pydantic>=2.10.0
13
+ Requires-Dist: pydantic-settings>=2.7.0
14
+ Requires-Dist: supabase>=2.11.0
15
+ Requires-Dist: cryptography>=44.0.0
16
+ Requires-Dist: structlog>=24.4.0
17
+ Requires-Dist: snowflake-connector-python>=3.12.0
18
+ Requires-Dist: clickhouse-driver>=0.2.9
19
+ Requires-Dist: google-cloud-bigquery>=3.27.0
20
+ Requires-Dist: asyncpg>=0.30.0
21
+ Requires-Dist: neo4j>=5.27.0
@@ -0,0 +1 @@
1
+ """db_utils package."""
@@ -0,0 +1,178 @@
1
"""ClickHouseCloud connection code."""

import os
import sys

from clickhouse_driver import Client

# Make this directory importable when the module is run as a script.
# NOTE(review): mutating sys.path at import time is a smell; prefer running
# as part of the package (python -m db_utils.clickhouse_access).
_here = os.path.dirname(os.path.abspath(__file__))
if _here not in sys.path:
    sys.path.insert(0, _here)
from db_utils.config import get_env_var
from db_utils.logger import _init_logger

logger = _init_logger(__file__)

# Known ClickHouse Cloud services mapped to their hostnames.
DEFAULT_SERVICE = "rw-standard"
SERVICES = {
    DEFAULT_SERVICE: "grkyl47mbo.us-west-2.aws.clickhouse.cloud",
    "rw-burst": "el5bpab3my.us-west-2.aws.clickhouse.cloud",
}
21
+
22
+
23
def get_ch_config(service: str = DEFAULT_SERVICE, **kwargs: dict) -> dict:
    """Build the ClickHouse connection configuration.

    Parameters
    ----------
    service : str, optional
        The ClickHouse service to connect to, by default DEFAULT_SERVICE.
    **kwargs : dict
        Overrides merged on top of the default configuration.

    Returns
    -------
    dict
        ClickHouse configuration (host/username/password plus overrides).

    Raises
    ------
    ValueError
        If ``service`` is not a known service name.
    """
    if service not in SERVICES:
        msg = f"Invalid service name: {service}. Available services: {list(SERVICES.keys())}"
        raise ValueError(msg)

    logger.info(f"Connecting to ClickHouse service: {service}")

    base = {
        "host": SERVICES.get(service, SERVICES[DEFAULT_SERVICE]),
        "username": get_env_var("ch_user", secret=True),
        "password": get_env_var("ch_password", secret=True),
    }
    return {**base, **kwargs}
51
+
52
+
53
def clickhouse_connect(**kwargs: dict) -> Client:
    """Open a secure client to ClickHouseCloud.

    Returns
    -------
    clickhouse_driver.Client
        ClickHouse connection client (TLS enabled).
    """
    cfg = get_ch_config(**kwargs)
    return Client(
        host=cfg["host"],
        user=cfg["username"],
        password=cfg["password"],
        send_receive_timeout=3300,  # generous timeout for long analytics queries
        secure=True,
    )
70
+
71
+
72
def run_ch_sql(
    query: str,
    *,
    with_column_types: bool = False,
    return_dataframe: bool = False,
    return_dict: bool = False,
    **kwargs: dict,
) -> list | tuple:
    """Run a SQL query on ClickHouseCloud.

    Parameters
    ----------
    query : str
        The SQL query to be executed on ClickHouse.
    with_column_types : bool, optional
        If True, include column types in the raw results, by default False.
    return_dataframe : bool, optional
        If True, return (row records, column names) via pandas, by default False.
    return_dict : bool, optional
        If True, return results as a list of dictionaries, by default False.
    **kwargs
        Additional keyword arguments to pass to the ClickHouse connection.

    Returns
    -------
    list or tuple
        List of tuples with query results; a (records, column_names) tuple
        when return_dataframe is True; a list of dicts when return_dict is True.
    """
    con = clickhouse_connect(**kwargs)
    # try/finally so the connection is closed even when the query raises
    # (previously a failing query leaked the connection).
    try:
        if return_dataframe:
            return _run_ch_sql_as_dataframe(con, query)
        if return_dict:
            return _run_ch_sql_as_dict(con, query)
        return con.execute(query, with_column_types=with_column_types)
    finally:
        con.disconnect_connection()
119
+
120
+
121
+ def _run_ch_sql_as_dataframe(con: Client, query: str) -> tuple:
122
+ """Run a SQL query on ClickHouseCloud and return results as a DataFrame.
123
+
124
+ Parameters
125
+ ----------
126
+ con : clickhouse_driver.Client
127
+ The ClickHouse connection client.
128
+ query : str
129
+ The SQL query to be executed on ClickHouse.
130
+
131
+ Returns
132
+ -------
133
+ DataFrame, list
134
+ A tuple containing a pandas DataFrame and a list of column names.
135
+ """
136
+ from pandas import DataFrame # inline import since Pandas can be heavy
137
+
138
+ result, columns = con.execute(query, with_column_types=True)
139
+
140
+ column_names = [c[0] for c in columns]
141
+ dict_result = DataFrame(result, columns=column_names).to_dict(orient="records")
142
+
143
+ return dict_result, column_names
144
+
145
+
146
def _run_ch_sql_as_dict(con: Client, query: str) -> list[dict]:
    """Run a SQL query on ClickHouseCloud and return results as a list of dictionaries.

    Parameters
    ----------
    con : clickhouse_driver.Client
        The ClickHouse connection client.
    query : str
        The SQL query to be executed on ClickHouse.

    Returns
    -------
    list
        A list of dictionaries representing the query results.
    """
    from clickhouse_driver.dbapi.extras import DictCursor

    # NOTE(review): DictCursor is a DB-API cursor normally created from a
    # dbapi Connection; here it is built directly around the native Client
    # with connection=None — presumably relying on the cursor only using the
    # client for execute/fetch. Confirm against the installed
    # clickhouse-driver version.
    cur = DictCursor(con, connection=None)
    cur.execute(query)
    return cur.fetchall()
166
+
167
+
168
if __name__ == "__main__":
    # Manual smoke tests: require CH credentials in AWS SM / environment.
    con = clickhouse_connect()
    result = con.execute("select current_date();")
    print(result)
    con.disconnect_connection()

    with clickhouse_connect() as ch_con:
        result = ch_con.execute("select current_date();")
        print(result)

    # run_ch_sql always returns its results; the previous
    # `return_results=True` kwarg was not a run_ch_sql parameter — it leaked
    # through **kwargs into the connection config dict where it was
    # silently ignored.
    print(run_ch_sql("select current_date();", service="rw-burst"))
@@ -0,0 +1,97 @@
1
+ """Configuration."""
2
+
3
+ import json
4
+ import os
5
+ from typing import Optional
6
+ from .logger import _init_logger
7
+
8
+ logger = _init_logger(__file__)
9
+
10
+ # Vendor-related env keys: always from AWS Secrets Manager (SecretId = user_id or AWS_VENDOR_SECRET_ID).
11
+
12
+
13
+ def _is_vendor_key(env_var: str) -> bool:
14
+ if env_var.startswith("PSQL_") or env_var.startswith("SF_") or env_var.startswith("BQ_"):
15
+ return True
16
+ if env_var in ("CH_HOST", "CH_USERNAME", "CH_PASSWORD", "CH_USER"):
17
+ return True
18
+ if env_var.startswith("CH_") and not env_var.startswith("CH_ANALYTICS_"):
19
+ return True
20
+ return False
21
+
22
+
23
def _get_secret_from_aws(secret_id: str) -> dict:
    """Fetch and parse a JSON secret from AWS Secrets Manager.

    Best effort: any client error or malformed JSON yields an empty dict.

    Parameters
    ----------
    secret_id : str
        Secrets Manager SecretId to fetch.

    Returns
    -------
    dict
        Parsed secret payload, or {} on failure.
    """
    import boto3
    from botocore.exceptions import ClientError

    sm = boto3.client(
        "secretsmanager",
        region_name=os.environ.get("AWS_REGION", "us-east-1"),
    )
    try:
        response = sm.get_secret_value(SecretId=secret_id)
    except ClientError as e:
        logger.debug("AWS SM get_secret_value %s: %s", secret_id, e)
        return {}
    try:
        return json.loads(response.get("SecretString") or "{}")
    except json.JSONDecodeError:
        return {}
38
+
39
+
40
def get_chartmetric_data_script_path() -> str:
    """Return the local directory for chartmetric_data_script.

    Uses the CHARTMETRIC_DATA_SCRIPT_PATH environment variable when set,
    otherwise a fixed EC2 default.

    Returns
    -------
    str
        local directory path
    """
    default_path = "/home/ec2-user/chartmetric_data_script"
    return os.environ.get("CHARTMETRIC_DATA_SCRIPT_PATH", default_path)
49
+
50
+
51
def get_env_var(name: str, secret: Optional[bool] = None, user_id: Optional[str] = None) -> Optional[str]:
    """Get value from AWS Secrets Manager (vendor keys) or environment.

    Vendor-related keys (PSQL_*, SF_*, CH_* except CH_ANALYTICS_*, BQ_*):
    - Fetched from AWS SM. SecretId = user_id or env AWS_VENDOR_SECRET_ID.
    Other keys:
    - From os.environ only.

    Parameters
    ----------
    name : str
        variable name (e.g. psql_user, ch_password)
    secret : bool, optional
        unused; kept for API compatibility
    user_id : str, optional
        user UUID for per-user secret; if None, uses AWS_VENDOR_SECRET_ID (system secret)

    Returns
    -------
    str or None
        value if found, else None

    Raises
    ------
    ValueError
        If ``name`` is empty. (Previously an ``assert``, which is stripped
        under ``python -O`` and so provided no validation.)
    """
    if not name:
        raise ValueError("name cannot be empty.")
    env_var = name.upper()

    if _is_vendor_key(env_var):
        secret_id = user_id or os.environ.get("AWS_VENDOR_SECRET_ID")
        if secret_id:
            secrets = _get_secret_from_aws(secret_id)
            if secrets:
                # Legacy secrets may store the username under CH_USERNAME.
                val = secrets.get(env_var) or (secrets.get("CH_USERNAME") if env_var == "CH_USER" else None)
                if val is not None and str(val).strip():
                    logger.debug("Using %s from AWS SM (secret_id=%s)", env_var, secret_id)
                    return str(val).strip()

    # Non-vendor keys (and vendor fallback) come from the plain environment.
    if value := os.environ.get(env_var):
        logger.info("Found %s in environment variables.", env_var)
        return value

    logger.error("Could not find %s in AWS SM or environment.", name)
    return None
92
+
93
+
94
if __name__ == "__main__":
    # test with python3 -m 'src.db_utils.config'
    # NOTE(review): these assertions hard-code environment-specific expected
    # values (an internal RDS hostname and an account email); they only pass
    # in the production environment and embed infrastructure details in code.
    assert get_env_var("psql_writer_host") == "prod2.cluster-cni52ceaa2ty.us-west-2.rds.amazonaws.com"
    assert get_env_var("chartmetric_api_username") == "data-eng@chartmetric.com"
@@ -0,0 +1,9 @@
1
+ """DB access."""
2
+
3
+ from .postgres.connect import db_connect as db_reader_connect
4
+ from .postgres.connect import get_connection as get_reader_db_connection
5
+ from .postgres.connect import get_cursor as get_reader_db_cursor
6
+ from .postgres.connect import run_pg_sql
7
+ from .postgres.helpers import insert_data
8
+ from .postgres.helpers import disconnect as db_disconnect
9
+ from .postgres.writer import db_connect, get_db_connection, get_db_cursor
@@ -0,0 +1,89 @@
1
+ """Implements functions for logging."""
2
+
3
+ import datetime
4
+ import logging
5
+
6
+
7
class Logger:
    """Thin wrapper over logging.Logger with DB-operation helpers.

    Provides info/debug/error plus skip, select, insert, update, delete
    (each prefixes the message with its operation name), exception, and
    stack_trace.
    """

    def __init__(self, name: str) -> None:
        self._log = logging.getLogger(name)

    def _tagged(self, tag: str, msg: str) -> None:
        # Single funnel for the operation-prefixed helpers below.
        self._log.info("%s: %s", tag, msg)

    def info(self, msg: str, *args) -> None:
        self._log.info(msg, *args)

    def skip(self, msg: str) -> None:
        self._tagged("skip", msg)

    def select(self, msg: str) -> None:
        self._tagged("select", msg)

    def insert(self, msg: str) -> None:
        self._tagged("insert", msg)

    def update(self, msg: str) -> None:
        self._tagged("update", msg)

    def delete(self, msg: str) -> None:
        self._tagged("delete", msg)

    def debug(self, msg: str, *args) -> None:
        self._log.debug(msg, *args)

    def error(self, msg: str, *args) -> None:
        self._log.error(msg, *args)

    def exception(self, msg: str) -> None:
        self._log.exception(msg)

    def stack_trace(self) -> None:
        self._log.info("stack_trace", exc_info=True)
42
+
43
+
44
+
45
def _init_logger(name: str) -> Logger:
    """Initialize the logger.

    Parameters
    ----------
    name : str
        Name of the logger (callers typically pass ``__file__``).

    Returns
    -------
    Logger
        Initialized logger instance
    """
    # Thin factory; kept as the single entry point so construction can be
    # extended (handlers, levels) in one place later.
    return Logger(name)
59
+
60
+
61
def log_with_print(message: str) -> None:
    """Print *message* prefixed with the current UTC timestamp.

    Parameters
    ----------
    message : str
        message to log
    """
    timestamp = datetime.datetime.now(tz=datetime.timezone.utc)
    print(f"{timestamp} {message}")
70
+
71
+
72
if __name__ == "__main__":
    # Manual smoke test: exercise every helper method once.
    logger = _init_logger(__file__)

    logger.info("start processing")
    logger.skip("update=False - spotify_track_id=56Kjskx12ksQss")
    logger.select("spotify_artist - tier=1 - artists_count=30")
    logger.insert("spotify_artist_insights - id=3963, timestp=2022-08-30")
    logger.update("spotify_artist_insights - id=3963, timestp=2022-08-30")
    logger.delete("spotify_artist_insights - id=3963, timestp=2022-08-30")
    logger.error("spotify_artist - tier=1 - artists_count=30")

    try:
        x = [1, 2, 3]
        print(x[5])  # This will raise an IndexError
    except IndexError:
        logger.exception("As expected, an error occurred while accessing the list.")

    logger.stack_trace()
@@ -0,0 +1,146 @@
1
+ """Neo4j connection code."""
2
+
3
+ from contextlib import contextmanager
4
+
5
+ from neo4j import GraphDatabase
6
+
7
+ from .config import get_env_var
8
+ from .logger import _init_logger
9
+
10
+ logger = _init_logger(__file__)
11
+
12
+
13
+ def user_id_to_neo4j_db(user_id: int | str) -> str:
14
+ """Convert a user id to a Neo4j database name.
15
+
16
+ Neo4j database names must be lowercase alphanumeric + dots.
17
+ Format: ``user.<id>`` for int ids, ``user.<uuid_hex>`` for legacy UUID strings.
18
+
19
+ Example: ``42`` → ``"user.42"``
20
+ """
21
+ if isinstance(user_id, int):
22
+ return f"user.{user_id}"
23
+ # Legacy UUID string support
24
+ hex_id = str(user_id).replace("-", "").lower()
25
+ return f"user.{hex_id}"
26
+
27
+
28
def ensure_user_database(driver, user_id: str) -> str | None:
    """Create a per-user Neo4j database if it does not exist.

    Must run against the ``system`` database.

    Parameters
    ----------
    driver : neo4j.Driver
        Open Neo4j driver.
    user_id : str
        User id; converted to a database name via ``user_id_to_neo4j_db``.

    Returns
    -------
    str or None
        The database name, or ``None`` if the instance does not support
        multi-database (e.g. Neo4j Aura Free/Pro).
        (Annotation fixed from ``-> str`` — the Aura path returns None.)
    """
    db_name = user_id_to_neo4j_db(user_id)
    try:
        with driver.session(database="system") as session:
            # Database names cannot be Cypher parameters; db_name is derived
            # from user_id_to_neo4j_db (lowercase alphanumerics + dots only).
            session.run(f"CREATE DATABASE `{db_name}` IF NOT EXISTS")
        logger.info("ensure_user_database: created %s", db_name)
        return db_name
    except Exception as e:
        # Aura / single-db editions reject administration commands; detect
        # that case by message text and fall back to the default database.
        if "not supported" in str(e).lower() or "UnsupportedAdministration" in str(e):
            logger.info(
                "ensure_user_database: CREATE DATABASE not supported "
                "(Neo4j Aura?). Using default database.",
            )
            return None
        raise
49
+
50
+
51
def get_neo4j_config(**kwargs) -> dict:
    """Build the Neo4j connection configuration.

    Parameters
    ----------
    kwargs
        Overrides for any of the configuration keys:
        - uri
        - user
        - password

    Returns
    -------
    dict
        Neo4j configuration
    """
    defaults = {
        "uri": get_env_var("neo4j_uri"),
        "user": get_env_var("neo4j_user") or "neo4j",
        "password": get_env_var("neo4j_password", secret=True),
    }
    return {**defaults, **kwargs}
74
+
75
+
76
def neo4j_connect(**kwargs):
    """Create a Neo4j driver from the resolved configuration.

    Returns
    -------
    neo4j.BoltDriver
        Neo4j driver
    """
    cfg = get_neo4j_config(**kwargs)
    auth = (cfg["user"], cfg["password"])
    return GraphDatabase.driver(cfg["uri"], auth=auth)
89
+
90
+
91
@contextmanager
def get_neo4j_session(database: str | None = None, **kwargs):
    """Yield a Neo4j session, closing the driver on exit.

    Parameters
    ----------
    database : str, optional
        Target Neo4j database name. ``None`` uses the server default.

    Yields
    ------
    neo4j.Session
        Neo4j session
    """
    driver = neo4j_connect(**kwargs)
    try:
        # Only pass `database` when set so the server default still applies.
        if database:
            session_cm = driver.session(database=database)
        else:
            session_cm = driver.session()
        with session_cm as session:
            yield session
    finally:
        driver.close()
114
+
115
+
116
def run_neo4j_query(query: str, params: dict | None = None, database: str | None = None, **kwargs):
    """Run a Cypher query and return results.

    Parameters
    ----------
    query : str
        Cypher query
    params : dict, optional
        Query parameters (annotation fixed: default is None, not a dict)
    database : str, optional
        Target Neo4j database name.
    **kwargs
        Additional keyword arguments for connection

    Returns
    -------
    list
        List of record dicts
    """
    with get_neo4j_session(database=database, **kwargs) as session:
        result = session.run(query, params or {})
        # Materialize inside the session: records are not readable after close.
        return [dict(record) for record in result]
138
+
139
+
140
if __name__ == "__main__":
    # Manual smoke test: print node counts per label (requires NEO4J_* creds).
    driver = neo4j_connect()
    with driver.session() as s:
        result = s.run("MATCH (n) RETURN labels(n)[0] AS label, count(n) AS cnt ORDER BY label")
        for rec in result:
            print(f" {rec['label']}: {rec['cnt']}")
    driver.close()
@@ -0,0 +1 @@
1
+ """Postgres."""
@@ -0,0 +1,138 @@
1
"""Postgres connection code."""

from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Optional

import psycopg2
from psycopg2 import connect as _connect
from psycopg2.extensions import UNICODE, connection
from psycopg2.extras import RealDictCursor
from psycopg2.pool import ThreadedConnectionPool

from .helpers import MAX_CLIENTS, get_config

# Decode text columns as unicode str process-wide.
psycopg2.extensions.register_type(UNICODE)
16
+
17
+
18
def db_connect(**kwargs: dict) -> connection:
    """Open a single Postgres (reader) connection.

    Parameters
    ----------
    **kwargs
        Overrides forwarded to get_config (e.g. user).

    Returns
    -------
    psycopg2.extensions.connection
        Postgres connection
    """
    cfg = get_config(**kwargs)
    conn_args = {key: cfg[key] for key in ("database", "user", "password", "host", "port")}
    return _connect(**conn_args)
39
+
40
+
41
@contextmanager
def get_connection(**kwargs: dict) -> Generator[connection, None, None]:
    """Yield a Postgres (reader) connection from a fresh pool.

    NOTE(review): a new ThreadedConnectionPool is built per call, which
    defeats pooling; consider a module-level pool if this is hot.

    Parameters
    ----------
    **kwargs
        Overrides forwarded to get_config (e.g. user).

    Yields
    ------
    psycopg2.extensions.connection
        Postgres connection
    """
    config = get_config(**kwargs)
    pool = ThreadedConnectionPool(
        1,
        MAX_CLIENTS,
        database=config["database"],
        user=config["user"],
        password=config["password"],
        host=config["host"],
        port=config["port"],
    )
    conn = None  # pre-bind: previously getconn() failing left conn unbound,
    # so the finally clause raised NameError and masked the real error
    try:
        conn = pool.getconn()
        yield conn
    finally:
        if conn is not None:
            pool.putconn(conn, close=True)
        # Close the pool itself; it was previously never released.
        pool.closeall()
70
+
71
+
72
@contextmanager
def get_cursor(commit: Optional[bool] = None, **kwargs: dict) -> Generator[RealDictCursor, None, None]:
    """Yield a RealDictCursor on a Postgres (reader) connection.

    Parameters
    ----------
    commit : boolean, optional
        whether to auto-commit at the end, by default False

    Yields
    ------
    psycopg2.extras.RealDictCursor
        Postgres cursor
    """
    with get_connection(**kwargs) as conn:
        cur = conn.cursor(cursor_factory=RealDictCursor)
        try:
            yield cur
            if commit:
                conn.commit()
        finally:
            cur.close()
94
+
95
+
96
def run_pg_sql(
    query: str, *, writer: bool = False, return_results: bool = False, **kwargs: dict
) -> Optional[list[dict[str, Any]]]:
    """Execute *query* against Postgres, optionally returning its rows.

    Parameters
    ----------
    query : str
        SQL to execute.
    writer : bool, optional
        whether to connect to the writer node (and commit), by default False
    return_results : bool, optional
        whether to return the query results, by default False
    **kwargs
        Overrides forwarded to the connection (e.g. user).

    Returns
    -------
    list or None
        list of {column: value} dicts when return_results is True
    """
    with get_connection(writer=writer, **kwargs) as con:
        with con.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(query)
            if writer:
                con.commit()
            return cur.fetchall() if return_results else None
124
+
125
+
126
if __name__ == "__main__":
    # Manual smoke tests: require PSQL_* credentials in AWS SM / environment.
    con = db_connect()
    cur = con.cursor()
    cur.execute("select current_date;")
    print(cur.fetchall())
    con.close()

    with get_connection() as pg_con:
        cur = pg_con.cursor()
        cur.execute("select current_date;")
        print(cur.fetchall())

    print(run_pg_sql("select current_date;", return_results=True))