ry-pg-utils 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ry_pg_utils/__init__.py +0 -0
- ry_pg_utils/config.py +44 -0
- ry_pg_utils/connect.py +288 -0
- ry_pg_utils/dynamic_table.py +199 -0
- ry_pg_utils/ipc/__init__.py +0 -0
- ry_pg_utils/ipc/channels.py +14 -0
- ry_pg_utils/notify_trigger.py +346 -0
- ry_pg_utils/parse_args.py +15 -0
- ry_pg_utils/pb_types/__init__.py +0 -0
- ry_pg_utils/pb_types/database_pb2.py +38 -0
- ry_pg_utils/pb_types/database_pb2.pyi +156 -0
- ry_pg_utils/pb_types/py.typed +0 -0
- ry_pg_utils/postgres_info.py +47 -0
- ry_pg_utils/py.typed +0 -0
- ry_pg_utils/updater.py +181 -0
- ry_pg_utils-1.0.2.dist-info/METADATA +473 -0
- ry_pg_utils-1.0.2.dist-info/RECORD +20 -0
- ry_pg_utils-1.0.2.dist-info/WHEEL +5 -0
- ry_pg_utils-1.0.2.dist-info/licenses/LICENSE +21 -0
- ry_pg_utils-1.0.2.dist-info/top_level.txt +1 -0
ry_pg_utils/__init__.py
ADDED
File without changes
|
ry_pg_utils/config.py
ADDED
@@ -0,0 +1,44 @@
|
|
1
|
+
import os
|
2
|
+
import socket
|
3
|
+
from dataclasses import dataclass
|
4
|
+
|
5
|
+
import dotenv
|
6
|
+
|
7
|
+
|
8
|
+
@dataclass
class Config:
    """Runtime settings for ry_pg_utils.

    A module-level instance, ``pg_config``, is built in this module from
    environment variables (after ``dotenv.load_dotenv()``) plus hard-coded
    flag values.
    """

    # Connection parameters; typed Optional because they are unset when the
    # matching POSTGRES_* environment variable is missing.
    postgres_host: str | None
    postgres_port: int | None
    postgres_db: str | None
    postgres_user: str | None
    postgres_password: str | None
    # NOTE(review): not read by connect.py or dynamic_table.py; presumably
    # consumed elsewhere in the package (e.g. updater.py) — confirm.
    do_publish_db: bool
    use_local_db_only: bool
    # Identifier written to backend_id columns and used as the table-name
    # suffix by connect.get_table_name.
    backend_id: str
    # When True, connect.Base gives every ORM model a non-null backend_id
    # column, and dynamic_table adds one to tables it creates.
    add_backend_to_all: bool
    # When True, connect.get_table_name appends "_<backend_id>" to table names.
    add_backend_to_tables: bool
    # When True, ManagedSession raises ValueError for an uninitialised db;
    # otherwise it logs the problem and yields None.
    raise_on_use_before_init: bool
|
21
|
+
|
22
|
+
|
23
|
+
dotenv.load_dotenv()


def _parse_port(raw: str | None) -> int | None:
    """Convert a POSTGRES_PORT value to an int.

    Returns None when the variable is unset or blank (e.g. ``POSTGRES_PORT=``
    in a .env file) instead of failing the whole import with ``int("")``.

    Raises:
        ValueError: if the value is non-blank but not an integer.
    """
    if raw is None or not raw.strip():
        return None
    try:
        return int(raw)
    except ValueError as exc:
        raise ValueError(f"POSTGRES_PORT must be an integer, got {raw!r}") from exc


def _default_backend_id() -> str:
    """Build a "<hostname>_<ip>" backend id, used when POSTGRES_USER is unset.

    ``socket.gethostbyname`` raises (``socket.gaierror``, an ``OSError``) when
    the host name does not resolve, which is common in containers; fall back
    to the bare host name rather than failing at import time.
    """
    hostname = socket.gethostname()
    try:
        address = socket.gethostbyname(hostname)
    except OSError:
        return hostname
    return f"{hostname}_{address}"


# Parse POSTGRES_PORT with proper None handling for mypy
_postgres_port_str = os.getenv("POSTGRES_PORT")
_postgres_port = _parse_port(_postgres_port_str)

pg_config = Config(
    postgres_host=os.getenv("POSTGRES_HOST"),
    postgres_port=_postgres_port,
    postgres_db=os.getenv("POSTGRES_DB"),
    postgres_user=os.getenv("POSTGRES_USER"),
    postgres_password=os.getenv("POSTGRES_PASSWORD"),
    do_publish_db=False,
    use_local_db_only=True,
    # The host-based fallback is only computed when POSTGRES_USER is unset/empty.
    backend_id=os.getenv("POSTGRES_USER") or _default_backend_id(),
    add_backend_to_all=True,
    add_backend_to_tables=True,
    raise_on_use_before_init=True,
)
|
ry_pg_utils/connect.py
ADDED
@@ -0,0 +1,288 @@
|
|
1
|
+
import contextlib
|
2
|
+
import importlib
|
3
|
+
import pkgutil
|
4
|
+
import threading
|
5
|
+
import typing as T
|
6
|
+
|
7
|
+
from ryutils import log
|
8
|
+
from sqlalchemy import Column, String, create_engine, event
|
9
|
+
from sqlalchemy.engine.base import Engine
|
10
|
+
from sqlalchemy.exc import OperationalError
|
11
|
+
from sqlalchemy.orm import declarative_base, declared_attr, scoped_session, sessionmaker
|
12
|
+
from sqlalchemy.orm.decl_api import DeclarativeMeta
|
13
|
+
from sqlalchemy.orm.scoping import ScopedSession
|
14
|
+
from sqlalchemy_utils import database_exists
|
15
|
+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
16
|
+
|
17
|
+
from ry_pg_utils import config
|
18
|
+
|
19
|
+
# Per-thread storage; set_backend_id/get_backend_id read and write the
# backend id here, and receive_before_flush consumes it.
_thread_local = threading.local()
# Attribute / column name under which the backend id is stored.
BACKEND_ID_VARIABLE = "backend_id"

# Registries keyed by database name. Other modules import these dicts by name
# (dynamic_table.py does `from ry_pg_utils.connect import ENGINE`), so they
# should be mutated in place rather than rebound.
ENGINE: T.Dict[str, Engine] = {}
THREAD_SAFE_SESSION_FACTORY: T.Dict[str, ScopedSession] = {}

# Base class with optional backend_id field.
# The choice is made once, at import time, from config.pg_config.
if config.pg_config.add_backend_to_all:
    # Add any common fields here
    class CommonBaseModel:
        """Mixin passed as ``cls`` to declarative_base so models gain a backend_id column."""

        @declared_attr
        def backend_id(cls: T.Any) -> Column:  # pylint: disable=no-self-argument
            # A new Column object is returned on every call (i.e. per model class).
            return Column(String(256), nullable=False)

    Base = declarative_base(name="Base", cls=CommonBaseModel)
else:
    Base = declarative_base(name="Base")

# Add type annotation for Base (annotation only, no reassignment) so type
# checkers treat the conditionally-built Base as a DeclarativeMeta.
Base: DeclarativeMeta  # type: ignore
|
39
|
+
|
40
|
+
|
41
|
+
def get_table_name(
    base_name: str, verbose: bool = False, backend_id: str = config.pg_config.backend_id
) -> str:
    """Return the physical table name used for *base_name*.

    When ``config.pg_config.add_backend_to_tables`` is set, the name is
    suffixed with ``_<backend_id>``; otherwise *base_name* is returned as-is.
    With *verbose* the resulting name is printed before returning.
    """
    use_suffix = config.pg_config.add_backend_to_tables
    display_name = f"{base_name}_{backend_id}" if use_suffix else f"{base_name}"
    if verbose:
        print(display_name)
    if use_suffix:
        return display_name
    return base_name
|
51
|
+
|
52
|
+
|
53
|
+
def init_engine(uri: str, db: str, **kwargs: T.Any) -> Engine:
    """Create (once) and return the engine registered under *db*.

    The first call for a given *db* builds the engine from *uri*. Any of the
    pool defaults below that the caller did not pass in *kwargs* are filled
    in. Later calls for the same *db* ignore *uri* and *kwargs* and return
    the cached engine.
    """
    if db in ENGINE:
        return ENGINE[db]

    pool_defaults: T.Dict[str, T.Any] = {
        "pool_recycle": 3600,  # recycle connections after 1 hour
        "pool_pre_ping": True,  # enable connection health checks
        "pool_size": 5,  # connections kept in the pool
        "max_overflow": 10,  # up to 10 extra connections beyond pool_size
    }
    # Caller-supplied options win over the defaults.
    engine_options = {**pool_defaults, **kwargs}
    ENGINE[db] = create_engine(uri, **engine_options)
    return ENGINE[db]
|
68
|
+
|
69
|
+
|
70
|
+
def get_engine(db: str) -> Engine:
    """Return the engine that init_engine registered for *db*.

    Raises:
        KeyError: if no engine has been registered for *db*.
    """
    engine = ENGINE[db]
    return engine
|
73
|
+
|
74
|
+
|
75
|
+
def clear_db() -> None:
    """Forget every registered engine and session factory.

    The registries are emptied in place rather than rebound to new dicts.
    Other modules hold references obtained via ``from ry_pg_utils.connect
    import ENGINE`` (dynamic_table.py does). Rebinding would leave them
    looking at the stale dict while init_engine fills a new one.

    Each scoped session is removed and each engine disposed, so their pooled
    connections are released instead of leaked.
    """
    for factory in THREAD_SAFE_SESSION_FACTORY.values():
        factory.remove()
    for engine in ENGINE.values():
        engine.dispose()
    THREAD_SAFE_SESSION_FACTORY.clear()
    ENGINE.clear()
|
80
|
+
|
81
|
+
|
82
|
+
def close_engine(db: str) -> None:
    """Dispose of *db*'s engine and drop its session factory, if registered.

    The scoped session is removed first, so that a connection it still holds
    goes back to the pool before the pool is disposed. Unknown *db* names
    are ignored. The registries are mutated in place.
    """
    factory = THREAD_SAFE_SESSION_FACTORY.pop(db, None)
    if factory is not None:
        factory.remove()
    engine = ENGINE.pop(db, None)
    if engine is not None:
        engine.dispose()
|
90
|
+
|
91
|
+
|
92
|
+
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    # tenacity calls the ``retry`` predicate with a RetryCallState, not with the
    # exception, so ``lambda e: isinstance(e, ...)`` was always False and never
    # retried. retry_if_exception_type inspects the attempt's outcome instead.
    retry=retry_if_exception_type((OperationalError, TimeoutError)),
    # After the last attempt re-raise the original exception (not
    # tenacity.RetryError), so init_database's `except OperationalError` still works.
    reraise=True,
)
def _init_session_factory(db: str) -> ScopedSession:
    """Create (once) and return the THREAD_SAFE_SESSION_FACTORY entry for *db*.

    Raises:
        ValueError: if init_engine has not been called for *db* (not retried).
    """
    if db not in ENGINE:
        raise ValueError(
            "Initialize ENGINE by calling init_engine before calling _init_session_factory!"
        )
    if db not in THREAD_SAFE_SESSION_FACTORY:
        THREAD_SAFE_SESSION_FACTORY[db] = scoped_session(sessionmaker(bind=ENGINE[db]))
    return THREAD_SAFE_SESSION_FACTORY[db]
|
108
|
+
|
109
|
+
|
110
|
+
def set_backend_id(backend_id: str) -> None:
    """Record *backend_id* for the calling thread only (see get_backend_id)."""
    thread_state = _thread_local
    setattr(thread_state, BACKEND_ID_VARIABLE, backend_id)
|
112
|
+
|
113
|
+
|
114
|
+
def get_backend_id() -> T.Optional[str]:
    """Return the calling thread's backend id, or None if none was set."""
    try:
        return getattr(_thread_local, BACKEND_ID_VARIABLE)
    except AttributeError:
        return None
|
116
|
+
|
117
|
+
|
118
|
+
@event.listens_for(scoped_session, "before_flush")
def receive_before_flush(session: ScopedSession, _flush_context: T.Any, _instances: T.Any) -> None:
    """Fill in a missing ``backend_id`` on modified and new objects before a flush.

    The value comes from get_backend_id() (the calling thread's id); nothing
    happens when it is unset or empty. Only objects that have a ``backend_id``
    attribute whose current value is None are touched, so explicit ids are kept.

    NOTE(review): the listener is registered on the scoped_session class; the
    object SQLAlchemy passes as *session* is presumably a plain Session rather
    than a ScopedSession — confirm the annotation.
    """
    current_id = get_backend_id()
    if not current_id:
        return

    # Modified objects first, then newly added ones.
    for obj in [*session.dirty, *session.new]:
        if hasattr(obj, "backend_id") and obj.backend_id is None:
            obj.backend_id = current_id
|
132
|
+
|
133
|
+
|
134
|
+
def is_session_factory_initialized() -> bool:
    """Return True once at least one database has a registered session factory."""
    return len(THREAD_SAFE_SESSION_FACTORY) > 0
|
136
|
+
|
137
|
+
|
138
|
+
@contextlib.contextmanager
def ManagedSession(  # pylint: disable=invalid-name
    db: T.Optional[str] = None, backend_id: T.Optional[str] = config.pg_config.backend_id
) -> T.Iterator[T.Optional[ScopedSession]]:
    """Get a session object whose lifecycle, commits and rollback are managed for you.

    On normal exit of the ``with`` block the session is committed. If the
    block (or the commit) raises, the session is rolled back and the
    exception propagates. The thread's scoped session is always removed
    afterwards.

    Connection errors are NOT retried automatically. A generator-based
    context manager cannot replay the ``with`` body, so callers that need
    retries must wrap the whole ``with`` statement themselves.

    Args:
        db: Database name. Defaults to the first database whose session
            factory was registered.
        backend_id: Backend id installed for the calling thread while the
            session is open (read by receive_before_flush). The previous value
            is restored on exit. Falsy values leave the thread's id untouched.

    Yields:
        The session, or None when *db* is not initialised and
        ``config.pg_config.raise_on_use_before_init`` is False.

    Raises:
        ValueError: when *db* is not initialised and
            ``config.pg_config.raise_on_use_before_init`` is True.

    Expected to be used as follows:
    ```
    # multiple db_operations are done within one session.
    with ManagedSession() as session:
        # db_operations is expected not to worry about session handling.
        db_operations.select(session, **kwargs)
        # after the with statement, the session commits to the database.
        db_operations.insert(session, **kwargs)
    ```
    """
    if db is None:
        # Assume we're just using the default (first registered) db. None means
        # nothing is initialised yet; that is reported below instead of
        # surfacing as an IndexError.
        db = next(iter(THREAD_SAFE_SESSION_FACTORY), None)

    if db not in THREAD_SAFE_SESSION_FACTORY:
        if config.pg_config.raise_on_use_before_init:
            raise ValueError(f"Call _init_session_factory for {db} before using ManagedSession!")
        log.print_fail(f"Call _init_session_factory for {db} before using ManagedSession!")
        yield None
        return

    factory = THREAD_SAFE_SESSION_FACTORY[db]
    session = factory()

    previous_backend_id = get_backend_id()
    if backend_id:
        set_backend_id(backend_id)

    try:
        yield session
        # commit() flushes pending changes itself; no separate flush needed.
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        # source:
        # https://stackoverflow.com/questions/
        # 21078696/why-is-my-scoped-session-raising-an-attributeerror-session-object-has-no-attr
        factory.remove()
        if backend_id:
            # Restore the caller's id so this session's backend_id does not
            # leak into later work on the same thread.
            if previous_backend_id is None:
                with contextlib.suppress(AttributeError):
                    delattr(_thread_local, BACKEND_ID_VARIABLE)
            else:
                set_backend_id(previous_backend_id)
|
193
|
+
|
194
|
+
|
195
|
+
def is_database_initialized(db: str) -> bool:
    """Return whether a session factory has been registered for *db*."""
    registered = THREAD_SAFE_SESSION_FACTORY.keys()
    return db in registered
|
199
|
+
|
200
|
+
|
201
|
+
def _import_models_from_module(module_path: str) -> None:
|
202
|
+
"""
|
203
|
+
Dynamically import all model classes from a given module path.
|
204
|
+
|
205
|
+
This ensures all SQLAlchemy models are registered with Base.metadata
|
206
|
+
before creating tables.
|
207
|
+
|
208
|
+
Args:
|
209
|
+
module_path: Dot-separated module path (e.g., 'database.models')
|
210
|
+
"""
|
211
|
+
try:
|
212
|
+
# Import the base module
|
213
|
+
base_module = importlib.import_module(module_path)
|
214
|
+
|
215
|
+
# Get the package path
|
216
|
+
if hasattr(base_module, "__path__"):
|
217
|
+
package_path = base_module.__path__
|
218
|
+
else:
|
219
|
+
# It's a single module, not a package
|
220
|
+
return
|
221
|
+
|
222
|
+
# Walk through all submodules
|
223
|
+
for _, modname, _ in pkgutil.walk_packages(
|
224
|
+
path=package_path,
|
225
|
+
prefix=f"{module_path}.",
|
226
|
+
):
|
227
|
+
try:
|
228
|
+
importlib.import_module(modname)
|
229
|
+
except Exception as e: # pylint: disable=broad-except
|
230
|
+
log.print_warn(f"Failed to import {modname}: {e}")
|
231
|
+
continue
|
232
|
+
|
233
|
+
log.print_ok_blue(f"Imported models from {module_path}")
|
234
|
+
except ModuleNotFoundError:
|
235
|
+
log.print_warn(f"Models module not found: {module_path}")
|
236
|
+
except Exception as e: # pylint: disable=broad-except
|
237
|
+
log.print_warn(f"Error importing models from {module_path}: {e}")
|
238
|
+
|
239
|
+
|
240
|
+
def _build_postgres_uri(
    db_name: str, db_user: str, db_password: str, db_host: str, db_port: int
) -> str:
    """Build a postgresql:// URI with the credentials percent-encoded.

    Characters such as '@', '/', ':' or '%' in a user name or password would
    otherwise be parsed as URI delimiters and corrupt the URL. The password
    is only included when a user is given.
    """
    from urllib.parse import quote  # pylint: disable=import-outside-toplevel

    if db_user and db_password:
        credentials = f"{quote(db_user, safe='')}:{quote(db_password, safe='')}@"
    elif db_user:
        credentials = f"{quote(db_user, safe='')}@"
    else:
        credentials = ""
    return f"postgresql://{credentials}{db_host}:{db_port}/{db_name}"


def init_database(
    db_name: str,
    db_user: str = "",
    db_password: str = "",
    db_host: str = "localhost",
    db_port: int = 5432,
    models_module: T.Optional[str] = None,
) -> None:
    """
    Initialize a database connection and create tables.

    An OperationalError is logged and swallowed, and the process continues
    without a database. This includes an unreachable server during the
    existence check.

    Args:
        db_name: Name of the database
        db_user: Database username
        db_password: Database password
        db_host: Database host
        db_port: Database port
        models_module: Optional dot-separated module path to your models
                      (e.g., 'database.models' or 'myapp.db.models').
                      All model classes in this module will be automatically
                      imported to register them with SQLAlchemy.
    """
    log.print_normal(f"Initializing database {db_name} at {db_host}:{db_port}")

    uri = _build_postgres_uri(db_name, db_user, db_password, db_host, db_port)
    engine = init_engine(uri, db_name)

    # Import models if module path provided (registers them on Base.metadata).
    if models_module:
        _import_models_from_module(models_module)

    try:
        # database_exists opens a connection, so an unreachable server raises
        # OperationalError here; it must be inside the try to reach the handler.
        if database_exists(engine.url):
            log.print_normal("Found existing database")
        else:
            # NOTE(review): nothing here actually creates the database; the
            # create_all below presumably fails if it does not exist — confirm intent.
            log.print_ok_blue("Creating new database!")

        Base.metadata.create_all(bind=engine)

        _init_session_factory(db_name)
    except OperationalError as exc:
        log.print_fail(f"Failed to initialize database: {exc}")
        log.print_normal("Continuing without db connection...")
|
ry_pg_utils/dynamic_table.py
ADDED
@@ -0,0 +1,199 @@
|
|
1
|
+
import typing as T
|
2
|
+
from datetime import datetime, timezone
|
3
|
+
|
4
|
+
from google.protobuf.descriptor import FieldDescriptor
|
5
|
+
from google.protobuf.message import Message
|
6
|
+
from ryutils import log
|
7
|
+
from sqlalchemy import (
|
8
|
+
Column,
|
9
|
+
DateTime,
|
10
|
+
Integer,
|
11
|
+
LargeBinary,
|
12
|
+
MetaData,
|
13
|
+
String,
|
14
|
+
Table,
|
15
|
+
func,
|
16
|
+
insert,
|
17
|
+
inspect,
|
18
|
+
types,
|
19
|
+
)
|
20
|
+
from sqlalchemy.engine import Engine
|
21
|
+
from sqlalchemy.exc import OperationalError
|
22
|
+
|
23
|
+
from ry_pg_utils.config import pg_config
|
24
|
+
from ry_pg_utils.connect import (
|
25
|
+
BACKEND_ID_VARIABLE,
|
26
|
+
ENGINE,
|
27
|
+
ManagedSession,
|
28
|
+
get_table_name,
|
29
|
+
init_engine,
|
30
|
+
)
|
31
|
+
|
32
|
+
# Protobuf field type -> storage category used by _create_dynamic_table.
FIELD_TYPE_MAP: T.Dict[int, str] = {
    # Floating point.
    **dict.fromkeys((FieldDescriptor.TYPE_DOUBLE, FieldDescriptor.TYPE_FLOAT), "float"),
    # Every integer encoding (plain, unsigned, fixed, zig-zag) plus enums.
    **dict.fromkeys(
        (
            FieldDescriptor.TYPE_INT64,
            FieldDescriptor.TYPE_UINT64,
            FieldDescriptor.TYPE_INT32,
            FieldDescriptor.TYPE_FIXED64,
            FieldDescriptor.TYPE_FIXED32,
            FieldDescriptor.TYPE_UINT32,
            FieldDescriptor.TYPE_ENUM,
            FieldDescriptor.TYPE_SFIXED32,
            FieldDescriptor.TYPE_SFIXED64,
            FieldDescriptor.TYPE_SINT32,
            FieldDescriptor.TYPE_SINT64,
        ),
        "int",
    ),
    FieldDescriptor.TYPE_BOOL: "bool",
    # Text; groups are mapped to "string" as well.
    **dict.fromkeys((FieldDescriptor.TYPE_STRING, FieldDescriptor.TYPE_GROUP), "string"),
    FieldDescriptor.TYPE_MESSAGE: "message",
    FieldDescriptor.TYPE_BYTES: "binary",
}

# Snapshot of the config flag, taken once at import time.
ADD_BACKEND_TO_ALL = pg_config.add_backend_to_all
|
54
|
+
|
55
|
+
|
56
|
+
def _get_field_types(message_class: T.Type) -> T.Dict[str, str]:
|
57
|
+
"""Retrieve the types of each field in a Protobuf message class."""
|
58
|
+
return {field.name: FIELD_TYPE_MAP[field.type] for field in message_class.DESCRIPTOR.fields}
|
59
|
+
|
60
|
+
|
61
|
+
def _combine_pb_timestamp(seconds: int, nanos: int) -> datetime:
|
62
|
+
"""Combines Protobuf seconds and nanos into a single datetime object."""
|
63
|
+
total_seconds = seconds + nanos / 1e9
|
64
|
+
return datetime.fromtimestamp(total_seconds, tz=timezone.utc)
|
65
|
+
|
66
|
+
|
67
|
+
def _create_dynamic_table(channel_name: str, pb_message: T.Type, db_name: str) -> Table:
    """Return the Table backing *channel_name* in *db_name*, creating it if needed.

    An existing table (name from get_table_name) is reflected from the
    database unchanged. Otherwise the columns are derived from *pb_message*'s
    protobuf fields, plus ``key`` and ``created_at`` and, when
    ADD_BACKEND_TO_ALL is set, a NOT NULL ``backend_id``. The table is then
    created.

    Raises:
        ValueError: if *db_name* has no engine, *pb_message* is falsy, or a
            field category is unsupported.
    """
    from sqlalchemy import BigInteger  # pylint: disable=import-outside-toplevel

    engine: T.Optional[Engine] = ENGINE.get(db_name)
    if not engine:
        raise ValueError(f"Database {db_name!r} not initialized")
    if not pb_message:
        raise ValueError(f"Protobuf message {pb_message!r} invalid")

    # Attempt inspect. On a connection error, drop the pooled connections and
    # retry once on the same engine. (init_engine would only hand back this
    # cached engine, and str(engine.url) may mask the password in SQLAlchemy 2.x.)
    try:
        with engine.connect() as conn:
            existing = inspect(conn).get_table_names()
    except OperationalError:
        engine.dispose()
        with engine.connect() as conn:
            existing = inspect(conn).get_table_names()

    tbl_name = get_table_name(channel_name)
    metadata = MetaData()
    if tbl_name in existing:
        return Table(tbl_name, metadata, autoload_with=engine)

    # Integer encodings whose range exceeds Postgres' 32-bit signed INTEGER.
    # NOTE(review): uint64/fixed64 values above 2**63 - 1 still overflow BIGINT.
    wide_int_types = {
        FieldDescriptor.TYPE_INT64,
        FieldDescriptor.TYPE_UINT64,
        FieldDescriptor.TYPE_FIXED64,
        FieldDescriptor.TYPE_SFIXED64,
        FieldDescriptor.TYPE_SINT64,
        FieldDescriptor.TYPE_UINT32,
        FieldDescriptor.TYPE_FIXED32,
    }

    cols: T.List[Column] = [
        Column("key", Integer, primary_key=True, autoincrement=True),
        Column("created_at", DateTime, server_default=func.now()),  # pylint: disable=not-callable
    ]
    for field in pb_message.DESCRIPTOR.fields:
        name = field.name
        ftype = FIELD_TYPE_MAP[field.type]
        if ftype == "string":
            cols.append(Column(name, String))
        elif ftype == "int":
            cols.append(Column(name, BigInteger if field.type in wide_int_types else Integer))
        elif ftype == "binary":
            cols.append(Column(name, LargeBinary))
        elif ftype == "bool":
            cols.append(Column(name, types.Boolean))
        elif ftype == "float":
            cols.append(Column(name, types.Float))
        elif ftype == "message":
            # NOTE(review): every message field becomes a DateTime column, but
            # DynamicTableDb.protobuf_to_dict only converts Timestamp messages.
            cols.append(
                Column(
                    name,
                    types.DateTime,
                    server_default=func.now(),  # pylint: disable=not-callable
                )
            )
        else:
            raise ValueError(f"Unsupported field type: {ftype!r}")

    # Skip when the message already defines a backend_id field, to avoid a
    # duplicate column definition.
    if ADD_BACKEND_TO_ALL and all(col.name != BACKEND_ID_VARIABLE for col in cols):
        cols.append(Column(BACKEND_ID_VARIABLE, String(256), nullable=False))

    tbl = Table(tbl_name, metadata, *cols, extend_existing=True)
    metadata.create_all(engine)
    return tbl
|
122
|
+
|
123
|
+
|
124
|
+
class DynamicTableDb:
    """Write protobuf messages to, and query, per-channel tables in one database.

    Each channel maps to a table that _create_dynamic_table reflects or
    creates on first use.
    """

    def __init__(self, db_name: str) -> None:
        # Name under which connect.init_engine / _init_session_factory registered the db.
        self.db_name = db_name

    @staticmethod
    def is_in_db(msg: Message, db_name: str, channel: str, attr: str, value: T.Any) -> bool:
        """Static entry-point for existence check."""
        return DynamicTableDb(db_name).inst_is_in_db(msg, channel, attr, value)

    def inst_is_in_db(
        self,
        message_pb: Message,
        channel_name: str,
        attr: str,
        value: T.Any,
    ) -> bool:
        """Return True if the channel's table has a row whose *attr* column equals *value*.

        Returns False when the table cannot be resolved (ValueError from
        _create_dynamic_table) or ManagedSession yields no session.

        Raises:
            ValueError: if the table has no column named *attr*.
        """
        try:
            tbl = _create_dynamic_table(channel_name, type(message_pb), self.db_name)
        except ValueError as e:
            print(f"is_in_db error: {e}")  # explicitly use print() here to avoid circular call
            return False
        if not hasattr(tbl.c, attr):
            raise ValueError(
                f"Column '{attr}' missing in '{channel_name}', cols={list(tbl.c.keys())}"
            )
        with ManagedSession(db=self.db_name) as sess:
            if sess is None:
                return False
            stmt = tbl.select().where(getattr(tbl.c, attr) == value)
            return bool(sess.execute(stmt).fetchone())

    @staticmethod
    def log_data_to_db(
        msg: Message,
        db_name: str,
        channel: str,
        log_print_failure: bool = True,
    ) -> None:
        """Static entry-point for logging messages."""
        assert db_name.strip(), "db_name is required"
        DynamicTableDb(db_name).add_message(channel, msg, log_print_failure)

    def add_message(
        self,
        channel_name: str,
        message_pb: Message,
        log_print_failure: bool = True,
        verbose: bool = False,
    ) -> None:
        """Insert *message_pb* as a new row of *channel_name*'s table.

        Best-effort: table-resolution and insert errors are swallowed. They are
        reported only when *verbose* is True, via log.print_fail, or print()
        when *log_print_failure* is False.
        """
        printer = log.print_fail if log_print_failure else print
        try:
            tbl = _create_dynamic_table(channel_name, type(message_pb), self.db_name)
        except ValueError as e:
            if verbose:
                printer(f"add_message table error: {e}")
            return
        data = self.protobuf_to_dict(message_pb)
        # The before_flush listener in connect.py only fills backend_id on ORM
        # objects, not on Core insert() values. The backend_id column that
        # _create_dynamic_table adds is NOT NULL with no default, so without
        # this every insert into such a table would fail (silently unless verbose).
        if BACKEND_ID_VARIABLE in tbl.c.keys():
            data.setdefault(BACKEND_ID_VARIABLE, pg_config.backend_id)
        with ManagedSession(db=self.db_name) as sess:
            if sess is None:
                return
            stmt = insert(tbl).values(**data)
            try:
                sess.execute(stmt)
            except Exception as e:  # pylint: disable=broad-exception-caught
                if verbose:
                    printer(f"insert failed: {e}")

    def protobuf_to_dict(self, message_pb: Message) -> T.Dict[str, T.Any]:
        """Return {field name: value} for every field declared on *message_pb*.

        Fields of message type ``Timestamp`` are converted to UTC datetimes;
        all other values are passed through unchanged.
        """
        out: T.Dict[str, T.Any] = {}
        for fld in message_pb.DESCRIPTOR.fields:
            val = getattr(message_pb, fld.name)
            if fld.type == FieldDescriptor.TYPE_MESSAGE and fld.message_type.name == "Timestamp":
                out[fld.name] = _combine_pb_timestamp(val.seconds, val.nanos)
            else:
                out[fld.name] = val
        return out
|
ry_pg_utils/ipc/__init__.py
ADDED
File without changes
|
ry_pg_utils/ipc/channels.py
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
from ry_redis_bus.channels import Channel
|
2
|
+
|
3
|
+
from ry_pg_utils.pb_types.database_pb2 import DatabaseConfigPb # pylint: disable=no-name-in-module
|
4
|
+
from ry_pg_utils.pb_types.database_pb2 import ( # pylint: disable=no-name-in-module
|
5
|
+
DatabaseNotificationPb,
|
6
|
+
)
|
7
|
+
from ry_pg_utils.pb_types.database_pb2 import ( # pylint: disable=no-name-in-module
|
8
|
+
DatabaseSettingsPb,
|
9
|
+
)
|
10
|
+
|
11
|
+
# Channels
|
12
|
+
# Module-level Channel objects. Each is built from a channel name and a
# message class imported from ry_pg_utils.pb_types.database_pb2.
# NOTE(review): Channel comes from ry_redis_bus; whether constructing one has
# side effects (e.g. registration) is not visible here — confirm before reordering.
DATABASE_CHANNEL = Channel("DATABASE_CHANNEL", DatabaseConfigPb)
DATABASE_CONFIG_CHANNEL = Channel("DATABASE_CONFIG_CHANNEL", DatabaseSettingsPb)
DATABASE_NOTIFY_CHANNEL = Channel("DATABASE_NOTIFY_CHANNEL", DatabaseNotificationPb)
|