ry-pg-utils 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
ry_pg_utils/config.py ADDED
@@ -0,0 +1,44 @@
1
+ import os
2
+ import socket
3
+ from dataclasses import dataclass
4
+
5
+ import dotenv
6
+
7
+
8
@dataclass
class Config:
    """Postgres connection settings plus package-wide behaviour flags.

    Populated once at import time (see pg_config below) from environment
    variables loaded via dotenv.
    """

    # Connection parameters — None when the matching env var is unset.
    postgres_host: str | None
    postgres_port: int | None
    postgres_db: str | None
    postgres_user: str | None
    postgres_password: str | None
    # Behaviour flags.
    do_publish_db: bool
    use_local_db_only: bool
    # Identifier stamped onto rows/tables to attribute them to this process.
    backend_id: str
    # Add a backend_id column to every declarative model.
    add_backend_to_all: bool
    # Suffix table names with the backend id.
    add_backend_to_tables: bool
    # Raise (rather than warn and yield None) when a session is used before init.
    raise_on_use_before_init: bool
21
+
22
+
23
dotenv.load_dotenv()

# Parse POSTGRES_PORT with proper None handling for mypy
_postgres_port_str = os.getenv("POSTGRES_PORT")
_postgres_port = int(_postgres_port_str) if _postgres_port_str is not None else None


def _default_backend_id() -> str:
    """Return POSTGRES_USER if set, else "<hostname>_<ip>".

    Fix: the original called socket.gethostbyname(socket.gethostname())
    unconditionally, which raises socket.gaierror at import time on hosts
    whose hostname does not resolve (e.g. minimal containers). Fall back to
    the bare hostname in that case.
    """
    user = os.getenv("POSTGRES_USER")
    if user:
        return user
    hostname = socket.gethostname()
    try:
        return f"{hostname}_{socket.gethostbyname(hostname)}"
    except socket.gaierror:
        # Hostname not resolvable — the hostname alone is still a usable id.
        return hostname


# Module-level singleton read throughout the package.
pg_config = Config(
    postgres_host=os.getenv("POSTGRES_HOST"),
    postgres_port=_postgres_port,
    postgres_db=os.getenv("POSTGRES_DB"),
    postgres_user=os.getenv("POSTGRES_USER"),
    postgres_password=os.getenv("POSTGRES_PASSWORD"),
    do_publish_db=False,
    use_local_db_only=True,
    backend_id=_default_backend_id(),
    add_backend_to_all=True,
    add_backend_to_tables=True,
    raise_on_use_before_init=True,
)
ry_pg_utils/connect.py ADDED
@@ -0,0 +1,288 @@
1
import contextlib
import importlib
import pkgutil
import threading
import typing as T

from ryutils import log
from sqlalchemy import Column, String, create_engine, event
from sqlalchemy.engine.base import Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import declarative_base, declared_attr, scoped_session, sessionmaker
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.scoping import ScopedSession
from sqlalchemy_utils import create_database, database_exists
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from ry_pg_utils import config
18
+
19
_thread_local = threading.local()
# Attribute name used to stash the backend id on _thread_local (and the
# column name used for dynamic tables).
BACKEND_ID_VARIABLE = "backend_id"

# Per-database registries, keyed by database name.
ENGINE: T.Dict[str, Engine] = {}
THREAD_SAFE_SESSION_FACTORY: T.Dict[str, ScopedSession] = {}

# Base class with optional backend_id field
if config.pg_config.add_backend_to_all:
    # Add any common fields here
    class CommonBaseModel:
        # Mixin: every model derived from Base gets a mandatory backend_id column.
        @declared_attr
        def backend_id(cls: T.Any) -> Column:  # pylint: disable=no-self-argument
            return Column(String(256), nullable=False)

    Base = declarative_base(name="Base", cls=CommonBaseModel)
else:
    Base = declarative_base(name="Base")

# Add type annotation for Base
Base: DeclarativeMeta  # type: ignore
39
+
40
+
41
def get_table_name(
    base_name: str, verbose: bool = False, backend_id: str = config.pg_config.backend_id
) -> str:
    """Return the physical table name for *base_name*.

    When add_backend_to_tables is enabled the backend id is appended as a
    suffix; otherwise the base name is used as-is. If *verbose* is set, the
    resolved name is printed.
    """
    if config.pg_config.add_backend_to_tables:
        resolved = f"{base_name}_{backend_id}"
    else:
        resolved = base_name
    if verbose:
        print(resolved)
    return resolved
51
+
52
+
53
def init_engine(uri: str, db: str, **kwargs: T.Any) -> Engine:
    """Create and cache an engine for *db*, or return the cached one.

    Pool defaults favour long-lived processes: stale connections are
    recycled hourly, health-checked before use, and pool growth is capped.
    Caller-supplied kwargs always win over these defaults.
    """
    if db not in ENGINE:
        kwargs.setdefault("pool_recycle", 3600)  # recycle connections after 1 hour
        kwargs.setdefault("pool_pre_ping", True)  # connection health checks
        kwargs.setdefault("pool_size", 5)  # baseline pool of connections
        kwargs.setdefault("max_overflow", 10)  # up to 10 extra connections
        ENGINE[db] = create_engine(uri, **kwargs)
    return ENGINE[db]
68
+
69
+
70
def get_engine(db: str) -> Engine:
    """Return the engine previously registered for *db* (KeyError if absent)."""
    return ENGINE[db]
73
+
74
+
75
def clear_db() -> None:
    """Forget every cached engine and session factory.

    Fix: the original rebound ENGINE and THREAD_SAFE_SESSION_FACTORY to new
    dicts, which silently broke any module that had done
    `from ry_pg_utils.connect import ENGINE` (as the dynamic-table module
    does) — those aliases kept pointing at the stale dicts. Clearing the
    existing dicts in place keeps all importers consistent.
    """
    ENGINE.clear()
    THREAD_SAFE_SESSION_FACTORY.clear()
80
+
81
+
82
def close_engine(db: str) -> None:
    """Dispose and unregister the engine and session factory for *db*.

    A no-op for any registry that has no entry under *db*.
    """
    engine = ENGINE.pop(db, None)
    if engine is not None:
        engine.dispose()
    THREAD_SAFE_SESSION_FACTORY.pop(db, None)
90
+
91
+
92
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    # Fix: tenacity's `retry=` is invoked with a RetryCallState, not the
    # raised exception. The original `lambda e: isinstance(e, (...))` tested
    # the call state and was therefore always False — transient connection
    # failures were never retried. retry_if_exception_type is the supported
    # way to retry on specific exception classes.
    retry=retry_if_exception_type((OperationalError, TimeoutError)),
)
def _init_session_factory(db: str) -> ScopedSession:
    """Initialize the THREAD_SAFE_SESSION_FACTORY.

    Creates (once) and returns a thread-safe scoped session factory bound to
    the engine registered for *db*.

    Raises:
        ValueError: if init_engine was not called for *db* first.
    """
    if db not in ENGINE:
        raise ValueError(
            "Initialize ENGINE by calling init_engine before calling _init_session_factory!"
        )
    if db not in THREAD_SAFE_SESSION_FACTORY:
        session_factory = sessionmaker(bind=ENGINE[db])
        THREAD_SAFE_SESSION_FACTORY[db] = scoped_session(session_factory)
    return THREAD_SAFE_SESSION_FACTORY[db]
108
+
109
+
110
def set_backend_id(backend_id: str) -> None:
    """Bind *backend_id* to the current thread (read back by get_backend_id)."""
    setattr(_thread_local, BACKEND_ID_VARIABLE, backend_id)
112
+
113
+
114
def get_backend_id() -> T.Optional[str]:
    """Return the backend id bound to the current thread, or None if unset."""
    return getattr(_thread_local, BACKEND_ID_VARIABLE, None)
116
+
117
+
118
@event.listens_for(scoped_session, "before_flush")
def receive_before_flush(session: ScopedSession, _flush_context: T.Any, _instances: T.Any) -> None:
    """Stamp the thread-bound backend id onto pending instances before flush.

    Only instances that expose a backend_id attribute which is still None
    are touched; explicitly-set ids are preserved.
    """
    backend_id = get_backend_id()
    if not backend_id:
        return

    # Dirty objects first, then new ones — same visiting order as before.
    for obj in (*session.dirty, *session.new):
        if hasattr(obj, "backend_id") and obj.backend_id is None:
            obj.backend_id = backend_id
132
+
133
+
134
def is_session_factory_initialized() -> bool:
    """True once at least one database session factory has been created."""
    return len(THREAD_SAFE_SESSION_FACTORY) > 0
136
+
137
+
138
@contextlib.contextmanager
def ManagedSession(  # pylint: disable=invalid-name
    db: T.Optional[str] = None, backend_id: T.Optional[str] = config.pg_config.backend_id
) -> T.Iterator[T.Optional[ScopedSession]]:
    """Get a session object whose lifecycle, commits and flush are managed for you.
    The session will automatically retry operations on connection errors.

    Expected to be used as follows:
    ```
    # multiple db_operations are done within one session.
    with ManagedSession() as session:
        # db_operations is expected not to worry about session handling.
        db_operations.select(session, **kwargs)
        # after the with statement, the session commits to the database.
        db_operations.insert(session, **kwargs)
    ```
    """
    global THREAD_SAFE_SESSION_FACTORY  # pylint: disable=global-variable-not-assigned
    if db is None:
        # assume we're just using the default db
        # NOTE(review): raises IndexError when no factory exists yet, and the
        # "default" is whichever db was initialised first — confirm intended.
        db = list(THREAD_SAFE_SESSION_FACTORY.keys())[0]

    if db not in THREAD_SAFE_SESSION_FACTORY:
        if config.pg_config.raise_on_use_before_init:
            raise ValueError(f"Call _init_session_factory for {db} before using ManagedSession!")
        # Best-effort mode: warn and hand the caller None instead of raising.
        log.print_fail(f"Call _init_session_factory for {db} before using ManagedSession!")
        yield None
        return

    # NOTE(review): tenacity's `retry=` is called with a RetryCallState, not
    # the raised exception, so this lambda always returns False (no retry).
    # Also, @retry on a generator function only guards generator *creation*,
    # not the body executed during iteration — confirm the retry ever fires.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=lambda e: isinstance(e, OperationalError),
    )
    def execute_with_retry(session: ScopedSession) -> T.Iterator[ScopedSession]:
        # Commit (then flush) on clean exit of the caller's block; roll back
        # and re-raise on any exception from the block or the commit.
        try:
            yield session
            session.commit()
            session.flush()
        except Exception:
            session.rollback()
            raise

    session = THREAD_SAFE_SESSION_FACTORY[db]()

    if backend_id:
        # Bind the backend id to this thread so receive_before_flush can
        # stamp it onto new/dirty rows.
        set_backend_id(backend_id)

    try:
        yield from execute_with_retry(session)
    finally:
        # source:
        # https://stackoverflow.com/questions/
        # 21078696/why-is-my-scoped-session-raising-an-attributeerror-session-object-has-no-attr
        THREAD_SAFE_SESSION_FACTORY[db].remove()
193
+
194
+
195
def is_database_initialized(db: str) -> bool:
    """Return True when a session factory exists for *db*."""
    # Plain membership test; the read-only `global` statement the original
    # carried was a no-op and is dropped.
    return db in THREAD_SAFE_SESSION_FACTORY
199
+
200
+
201
+ def _import_models_from_module(module_path: str) -> None:
202
+ """
203
+ Dynamically import all model classes from a given module path.
204
+
205
+ This ensures all SQLAlchemy models are registered with Base.metadata
206
+ before creating tables.
207
+
208
+ Args:
209
+ module_path: Dot-separated module path (e.g., 'database.models')
210
+ """
211
+ try:
212
+ # Import the base module
213
+ base_module = importlib.import_module(module_path)
214
+
215
+ # Get the package path
216
+ if hasattr(base_module, "__path__"):
217
+ package_path = base_module.__path__
218
+ else:
219
+ # It's a single module, not a package
220
+ return
221
+
222
+ # Walk through all submodules
223
+ for _, modname, _ in pkgutil.walk_packages(
224
+ path=package_path,
225
+ prefix=f"{module_path}.",
226
+ ):
227
+ try:
228
+ importlib.import_module(modname)
229
+ except Exception as e: # pylint: disable=broad-except
230
+ log.print_warn(f"Failed to import {modname}: {e}")
231
+ continue
232
+
233
+ log.print_ok_blue(f"Imported models from {module_path}")
234
+ except ModuleNotFoundError:
235
+ log.print_warn(f"Models module not found: {module_path}")
236
+ except Exception as e: # pylint: disable=broad-except
237
+ log.print_warn(f"Error importing models from {module_path}: {e}")
238
+
239
+
240
def init_database(
    db_name: str,
    db_user: str = "",
    db_password: str = "",
    db_host: str = "localhost",
    db_port: int = 5432,
    models_module: T.Optional[str] = None,
) -> None:
    """
    Initialize a database connection and create tables.

    Fixes over the original:
      * "Creating new database!" was logged but nothing actually created the
        database, so create_all failed when it did not exist — now calls
        sqlalchemy_utils.create_database.
      * database_exists ran outside the try block, so an unreachable server
        crashed the caller instead of hitting the "continue without db"
        path — the existence check now shares the OperationalError handler.

    Args:
        db_name: Name of the database
        db_user: Database username
        db_password: Database password
        db_host: Database host
        db_port: Database port
        models_module: Optional dot-separated module path to your models
                       (e.g., 'database.models' or 'myapp.db.models').
                       All model classes in this module will be automatically
                       imported to register them with SQLAlchemy.
    """
    log.print_normal(f"Initializing database {db_name} at {db_host}:{db_port}")

    if db_user and db_password:
        uri = f"postgresql://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}"
    elif db_user:
        uri = f"postgresql://{db_user}@{db_host}:{db_port}/{db_name}"
    else:
        uri = f"postgresql://{db_host}:{db_port}/{db_name}"

    engine = init_engine(uri, db_name)

    # Import models if module path provided (registers them on Base.metadata).
    if models_module:
        _import_models_from_module(models_module)

    try:
        if database_exists(engine.url):
            log.print_normal("Found existing database")
        else:
            log.print_ok_blue("Creating new database!")
            create_database(engine.url)

        Base.metadata.create_all(bind=engine)

        _init_session_factory(db_name)
    except OperationalError as exc:
        log.print_fail(f"Failed to initialize database: {exc}")
        log.print_normal("Continuing without db connection...")
@@ -0,0 +1,199 @@
1
+ import typing as T
2
+ from datetime import datetime, timezone
3
+
4
+ from google.protobuf.descriptor import FieldDescriptor
5
+ from google.protobuf.message import Message
6
+ from ryutils import log
7
+ from sqlalchemy import (
8
+ Column,
9
+ DateTime,
10
+ Integer,
11
+ LargeBinary,
12
+ MetaData,
13
+ String,
14
+ Table,
15
+ func,
16
+ insert,
17
+ inspect,
18
+ types,
19
+ )
20
+ from sqlalchemy.engine import Engine
21
+ from sqlalchemy.exc import OperationalError
22
+
23
+ from ry_pg_utils.config import pg_config
24
+ from ry_pg_utils.connect import (
25
+ BACKEND_ID_VARIABLE,
26
+ ENGINE,
27
+ ManagedSession,
28
+ get_table_name,
29
+ init_engine,
30
+ )
31
+
32
# Protobuf FieldDescriptor.TYPE_* constants mapped to the simple column-type
# tags ("float", "int", "bool", "string", "message", "binary") used when
# building dynamic tables.
FIELD_TYPE_MAP: T.Dict[int, str] = {
    FieldDescriptor.TYPE_DOUBLE: "float",
    FieldDescriptor.TYPE_FLOAT: "float",
    FieldDescriptor.TYPE_INT64: "int",
    FieldDescriptor.TYPE_UINT64: "int",
    FieldDescriptor.TYPE_INT32: "int",
    FieldDescriptor.TYPE_FIXED64: "int",
    FieldDescriptor.TYPE_FIXED32: "int",
    FieldDescriptor.TYPE_BOOL: "bool",
    FieldDescriptor.TYPE_STRING: "string",
    # Groups are stored as strings; nested messages get special handling.
    FieldDescriptor.TYPE_GROUP: "string",
    FieldDescriptor.TYPE_MESSAGE: "message",
    FieldDescriptor.TYPE_BYTES: "binary",
    FieldDescriptor.TYPE_UINT32: "int",
    FieldDescriptor.TYPE_ENUM: "int",
    FieldDescriptor.TYPE_SFIXED32: "int",
    FieldDescriptor.TYPE_SFIXED64: "int",
    FieldDescriptor.TYPE_SINT32: "int",
    FieldDescriptor.TYPE_SINT64: "int",
}

# Mirrors pg_config.add_backend_to_all: append a backend_id column to
# every dynamically-created table.
ADD_BACKEND_TO_ALL = pg_config.add_backend_to_all
54
+
55
+
56
def _get_field_types(message_class: T.Type) -> T.Dict[str, str]:
    """Retrieve the types of each field in a Protobuf message class."""
    field_types: T.Dict[str, str] = {}
    for descriptor_field in message_class.DESCRIPTOR.fields:
        # KeyError here means a proto field type missing from FIELD_TYPE_MAP.
        field_types[descriptor_field.name] = FIELD_TYPE_MAP[descriptor_field.type]
    return field_types
59
+
60
+
61
+ def _combine_pb_timestamp(seconds: int, nanos: int) -> datetime:
62
+ """Combines Protobuf seconds and nanos into a single datetime object."""
63
+ total_seconds = seconds + nanos / 1e9
64
+ return datetime.fromtimestamp(total_seconds, tz=timezone.utc)
65
+
66
+
67
def _create_dynamic_table(channel_name: str, pb_message: T.Type, db_name: str) -> Table:
    """Create (or reflect) the table backing *channel_name* for *pb_message*.

    Columns are derived from the protobuf field types via FIELD_TYPE_MAP; an
    autoincrement primary key and a server-defaulted created_at column are
    always added, plus a backend_id column when configured.

    Raises:
        ValueError: if the db is not initialized, the message type is falsy,
            or a field maps to an unsupported type tag.
    """
    engine: T.Optional[Engine] = ENGINE.get(db_name)
    if not engine:
        raise ValueError(f"Database {db_name!r} not initialized")
    if not pb_message:
        raise ValueError(f"Protobuf message {pb_message!r} invalid")

    # attempt inspect
    try:
        with engine.connect() as conn:
            inspector = inspect(conn)
            existing = inspector.get_table_names()
    except OperationalError:
        # Connection went stale: dispose the pool and retry the inspect once.
        # NOTE(review): init_engine returns the cached (just-disposed) engine
        # since db_name is still registered; dispose() only resets its pool,
        # so reuse works — confirm that is the intent.
        engine.dispose()
        engine = init_engine(str(engine.url), db_name)
        with engine.connect() as conn:
            inspector = inspect(conn)
            existing = inspector.get_table_names()

    tbl_name = get_table_name(channel_name)
    metadata = MetaData()
    if tbl_name in existing:
        # Table already exists — reflect its current definition instead of
        # rebuilding the column list.
        return Table(tbl_name, metadata, autoload_with=engine)

    cols: T.List[Column] = [
        Column("key", Integer, primary_key=True, autoincrement=True),
        Column("created_at", DateTime, server_default=func.now()),  # pylint: disable=not-callable
    ]
    for name, ftype in _get_field_types(pb_message).items():
        if ftype == "string":
            cols.append(Column(name, String))
        elif ftype == "int":
            cols.append(Column(name, Integer))
        elif ftype == "binary":
            cols.append(Column(name, LargeBinary))
        elif ftype == "bool":
            cols.append(Column(name, types.Boolean))
        elif ftype == "float":
            cols.append(Column(name, types.Float))
        elif ftype == "message":
            # NOTE(review): every nested message field becomes a DateTime
            # column — this looks specific to google.protobuf.Timestamp
            # (cf. protobuf_to_dict); confirm behaviour for other messages.
            cols.append(
                Column(
                    name,
                    types.DateTime,
                    server_default=func.now(),  # pylint: disable=not-callable
                )
            )
        else:
            raise ValueError(f"Unsupported field type: {ftype!r}")
    if ADD_BACKEND_TO_ALL:
        cols.append(Column(BACKEND_ID_VARIABLE, String(256), nullable=False))

    tbl = Table(tbl_name, metadata, *cols, extend_existing=True)
    metadata.create_all(engine)
    return tbl
122
+
123
+
124
class DynamicTableDb:
    """Logs protobuf messages into per-channel tables created on demand."""

    def __init__(self, db_name: str) -> None:
        # Name of the (already initialised) database to operate on.
        self.db_name = db_name

    @staticmethod
    def is_in_db(msg: Message, db_name: str, channel: str, attr: str, value: T.Any) -> bool:
        """Static entry-point for existence check."""
        return DynamicTableDb(db_name).inst_is_in_db(msg, channel, attr, value)

    def inst_is_in_db(
        self,
        message_pb: Message,
        channel_name: str,
        attr: str,
        value: T.Any,
    ) -> bool:
        """Return True if a row with column *attr* == *value* exists.

        Returns False (rather than raising) when the table cannot be
        created/reflected or no session is available; raises ValueError when
        the table exists but lacks the requested column.
        """
        try:
            tbl = _create_dynamic_table(channel_name, type(message_pb), self.db_name)
        except ValueError as e:
            print(f"is_in_db error: {e}")  # explicitly use print() here to avoid circular call
            return False
        if not hasattr(tbl.c, attr):
            raise ValueError(
                f"Column '{attr}' missing in '{channel_name}', cols={list(tbl.c.keys())}"
            )
        with ManagedSession(db=self.db_name) as sess:
            if sess is None:
                # Session factory unavailable — treat as "not found".
                return False
            stmt = tbl.select().where(getattr(tbl.c, attr) == value)
            return bool(sess.execute(stmt).fetchone())

    @staticmethod
    def log_data_to_db(
        msg: Message,
        db_name: str,
        channel: str,
        log_print_failure: bool = True,
    ) -> None:
        """Static entry-point for logging messages."""
        assert db_name.strip(), "db_name is required"
        DynamicTableDb(db_name).add_message(channel, msg, log_print_failure)

    def add_message(
        self,
        channel_name: str,
        message_pb: Message,
        log_print_failure: bool = True,
        verbose: bool = False,
    ) -> None:
        """Insert one protobuf message as a row in the channel's table.

        Failures are swallowed (reported only when *verbose*) so that
        logging never breaks the caller.
        """
        printer = log.print_fail if log_print_failure else print
        try:
            tbl = _create_dynamic_table(channel_name, type(message_pb), self.db_name)
        except ValueError as e:
            if verbose:
                printer(f"add_message table error: {e}")
            return
        data = self.protobuf_to_dict(message_pb)
        with ManagedSession(db=self.db_name) as sess:
            if sess is None:
                return
            stmt = insert(tbl).values(**data)
            try:
                sess.execute(stmt)
            except Exception as e:  # pylint: disable=broad-exception-caught
                if verbose:
                    printer(f"insert failed: {e}")

    def protobuf_to_dict(self, message_pb: Message) -> T.Dict[str, T.Any]:
        """Convert *message_pb* into a column-name -> value dict.

        google.protobuf Timestamp fields are converted to timezone-aware
        datetimes; every other field value is passed through unchanged.
        """
        out: T.Dict[str, T.Any] = {}
        for fld in message_pb.DESCRIPTOR.fields:
            val = getattr(message_pb, fld.name)
            if fld.type == FieldDescriptor.TYPE_MESSAGE and fld.message_type.name == "Timestamp":
                out[fld.name] = _combine_pb_timestamp(val.seconds, val.nanos)
            else:
                out[fld.name] = val
        return out
File without changes
@@ -0,0 +1,14 @@
1
+ from ry_redis_bus.channels import Channel
2
+
3
+ from ry_pg_utils.pb_types.database_pb2 import DatabaseConfigPb # pylint: disable=no-name-in-module
4
+ from ry_pg_utils.pb_types.database_pb2 import ( # pylint: disable=no-name-in-module
5
+ DatabaseNotificationPb,
6
+ )
7
+ from ry_pg_utils.pb_types.database_pb2 import ( # pylint: disable=no-name-in-module
8
+ DatabaseSettingsPb,
9
+ )
10
+
11
# Channels
# Redis-bus channels, each pairing a channel name with the protobuf payload
# type carried on it.
DATABASE_CHANNEL = Channel("DATABASE_CHANNEL", DatabaseConfigPb)
DATABASE_CONFIG_CHANNEL = Channel("DATABASE_CONFIG_CHANNEL", DatabaseSettingsPb)
DATABASE_NOTIFY_CHANNEL = Channel("DATABASE_NOTIFY_CHANNEL", DatabaseNotificationPb)