sqlnotify 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlnotify/utils.py ADDED
@@ -0,0 +1,165 @@
1
+ from sqlalchemy.engine import Engine
2
+ from sqlalchemy.ext.asyncio import AsyncEngine
3
+
4
+ from .constants import MAX_SQLNOTIFY_IDENTIFER_BYTES, MAX_SQLNOTIFY_PAYLOAD_BYTES
5
+ from .exceptions import SQLNotifyIdentifierSizeError, SQLNotifyPayloadSizeError
6
+
7
+
8
def extract_database_url(engine: AsyncEngine | Engine) -> str:
    """
    Extract the database URL from an engine for use with asyncpg LISTEN.

    Renders the engine URL as a string and removes a known SQLAlchemy driver
    suffix (for example ``+asyncpg``) from the scheme, producing a plain
    ``postgresql://`` or ``sqlite://`` URL.

    Only the scheme (the text before the first ``:``) is inspected and
    rewritten, so a database name, user name or password that happens to
    contain ``sqlite`` or ``postgresql`` can neither select the wrong dialect
    nor be modified. URLs with an unknown dialect or driver are returned
    unchanged.

    Args:
        engine (Union[AsyncEngine, Engine]): SQLAlchemy Engine or AsyncEngine

    Returns:
        str: Connection URL with any known driver suffix removed. The password
        is rendered in clear text (``hide_password=False``), so treat the
        result as a secret.
    """

    url = engine.url.render_as_string(hide_password=False)

    # Driver suffixes that are stripped, keyed by dialect name.
    known_drivers = {
        "sqlite": {"aiosqlite"},
        "postgresql": {"asyncpg", "aiopg", "psycopg2", "psycopg", "pg8000"},
    }

    # "dialect+driver://rest" -> scheme="dialect+driver", remainder="//rest".
    # The scheme never contains ":", so the first ":" always terminates it.
    scheme, separator, remainder = url.partition(":")
    dialect, plus, driver = scheme.partition("+")

    # Exact driver match: "psycopg" must not be stripped out of "psycopg2cffi".
    if plus and driver in known_drivers.get(dialect, ()):
        return f"{dialect}{separator}{remainder}"

    return url
34
+
35
+
36
def strip_database_query_params(db_url: str) -> str:
    """
    Return a connection URL with its query string removed.

    Everything from the first ``?`` onward is dropped. A URL without a ``?``
    is returned unchanged.

    Args:
        db_url (str): PostgreSQL connection URL

    Returns:
        str: Connection URL without query parameters
    """

    # partition() yields the whole string as the head when "?" is absent.
    base_url, _, _ = db_url.partition("?")
    return base_url
52
+
53
+
54
def replace_spaces_with_underscores(name: str) -> str:
    """
    Trim surrounding whitespace and turn each remaining space into ``_``.

    Only the space character is substituted; other interior whitespace such
    as tabs is left untouched, and consecutive spaces each become their own
    underscore.

    Args:
        name (str): Input string

    Returns:
        str: Trimmed string with spaces replaced by underscores
    """

    trimmed = name.strip()
    # Splitting on a literal " " keeps empty pieces, so runs of spaces map to
    # runs of underscores when the pieces are joined back together.
    return "_".join(trimmed.split(" "))
66
+
67
+
68
def validate_payload_size(payload_str: str, allow_overflow: bool = False) -> bool:
    """
    Check a payload's UTF-8 size against ``MAX_SQLNOTIFY_PAYLOAD_BYTES``.

    Args:
        payload_str (str): JSON payload string
        allow_overflow (bool): If True, an oversized payload yields ``False``
            instead of an exception

    Returns:
        bool: True when the payload fits, False when it is too large and
        ``allow_overflow`` is True

    Raises:
        SQLNotifyPayloadSizeError: If the payload is too large and
            ``allow_overflow`` is False
    """

    # The limit is in bytes, so measure the encoded form, not the char count.
    size_in_bytes = len(payload_str.encode("utf-8"))

    if size_in_bytes <= MAX_SQLNOTIFY_PAYLOAD_BYTES:
        return True

    if not allow_overflow:
        summary = (
            f"Payload size {size_in_bytes} bytes exceeds SQLNotify "
            f"limit of {MAX_SQLNOTIFY_PAYLOAD_BYTES} bytes. "
        )
        advice = (
            "Consider: (1) reducing extra_columns, (2) using use_overflow_table=True, "
            "or (3) sending only ID and fetching full data from database."
        )
        raise SQLNotifyPayloadSizeError(summary + advice)

    return False
97
+
98
+
99
def hash_identifier(s: str, max_length: int = 20) -> str:
    """
    Hash an identifier string to a short, deterministic hex string using SHA-256.

    The input is UTF-8 encoded, hashed with SHA-256, and the hexadecimal
    digest is truncated to ``max_length`` characters. Hex digits are ASCII,
    so the result occupies exactly one byte per character.

    Args:
        s (str): Input string to hash
        max_length (int): Maximum length of the output string. Must be at
            least 1. A SHA-256 hex digest has 64 characters, so larger values
            return the full 64-character digest.

    Returns:
        str: The first ``max_length`` lowercase hex characters of the digest

    Raises:
        ValueError: If ``max_length`` is less than 1. A zero length would map
            every input to the empty string, and a negative length would
            slice from the end and return more characters than intended.

    Note:
        Collisions are possible with any truncated hash. The default of 20 hex
        characters keeps 80 bits of the digest (~1.2e24 possible values),
        which makes accidental collisions extremely unlikely in practice.
    """
    import hashlib

    if max_length < 1:
        raise ValueError(f"max_length must be at least 1, got {max_length}")

    return hashlib.sha256(s.encode("utf-8")).hexdigest()[:max_length]
124
+
125
+
126
def validate_identifier_size(
    identifiers: list[str],
    stop_on_first_error: bool = True,
) -> bool:
    """
    Ensure every identifier (e.g. channel, function, trigger names) fits the byte limit.

    Each identifier's UTF-8 encoded length is compared against
    ``MAX_SQLNOTIFY_IDENTIFER_BYTES``.

    Args:
        identifiers (list[str]): The identifiers to validate
        stop_on_first_error (bool): If True, raise as soon as the first
            oversized identifier is found. If False, check all of them and
            raise a single error whose message joins every violation with "; ".

    Returns:
        bool: True if all identifiers are within limits

    Raises:
        SQLNotifyIdentifierSizeError: If one or more identifiers exceed the byte limit
    """

    def _violations():
        # Lazily yield one message per oversized identifier, in input order.
        for name in identifiers:
            size = len(name.encode("utf-8"))
            if size <= MAX_SQLNOTIFY_IDENTIFER_BYTES:
                continue
            yield (
                f"Identifier '{name}' is {size} bytes, "
                f"and exceeds limit of {MAX_SQLNOTIFY_IDENTIFER_BYTES} bytes."
            )

    violations = _violations()

    if stop_on_first_error:
        # Pull only the first violation; later identifiers are never inspected.
        first = next(violations, None)
        if first is not None:
            raise SQLNotifyIdentifierSizeError(first)
        return True

    collected = list(violations)
    if collected:
        raise SQLNotifyIdentifierSizeError("; ".join(collected))

    return True
sqlnotify/watcher.py ADDED
@@ -0,0 +1,72 @@
1
+ import logging
2
+ from typing import Any
3
+
4
+ from sqlalchemy import inspect as sa_inspect
5
+ from sqlalchemy.orm import Mapper
6
+
7
+ from .constants import PACKAGE_NAME
8
+ from .exceptions import SQLNotifyConfigurationError
9
+ from .types import Operation
10
+ from .utils import hash_identifier, replace_spaces_with_underscores, validate_identifier_size
11
+
12
+
13
class Watcher:
    """
    Configuration class for SQLNotify Watcher

    Holds the settings for watching one operation on one mapped model, and
    derives from them the table/schema names plus the channel, function and
    trigger identifiers.
    """

    def __init__(
        self,
        model: type,
        operation: Operation,
        extra_columns: list[str] | None,
        trigger_columns: list[str] | None,
        primary_keys: list[str],
        channel_label: str | None = None,
        use_overflow_table: bool = False,
        logger: logging.Logger | None = None,
    ):
        """
        Build the watcher configuration and derive its identifiers.

        Args:
            model (type): Mapped model class to watch
            operation (Operation): Operation being watched; its ``value`` is
                used in the default channel name
            extra_columns (list[str] | None): Extra columns; ``None`` is stored as ``[]``
            trigger_columns (list[str] | None): Trigger columns, stored as given
            primary_keys (list[str]): Primary key column names
            channel_label (str | None): Optional label. When set, the channel
                name is built from the label instead of schema/table/operation
            use_overflow_table (bool): When True, ``overflow_table_name`` is set
            logger (logging.Logger | None): Optional logger for inspection errors

        Raises:
            SQLNotifyConfigurationError: If the model cannot be inspected
            SQLNotifyIdentifierSizeError: If a derived identifier is too long
        """
        self.model = model
        self.operation = operation
        self.extra_columns = extra_columns or []
        self.trigger_columns = trigger_columns
        self.primary_keys = primary_keys
        self.channel_label = channel_label
        self.use_overflow_table = use_overflow_table
        self._logger = logger

        # Resolved defensively so error reporting cannot itself fail when
        # ``model`` is not a class and therefore has no ``__name__``.
        model_name = getattr(model, "__name__", repr(model))

        try:
            sa_mapper: Mapper[Any] = sa_inspect(model)
            self.table_name = (
                sa_mapper.local_table.name  # type: ignore[attr-defined]
                or model.__table__.name  # type: ignore[attr-defined]
                or model.__tablename__  # type: ignore[attr-defined]
                or model.__name__.lower()
            )
            self.schema_name = sa_mapper.local_table.schema or self._schema_from_table_args(model)
        except Exception as e:
            if self._logger:
                self._logger.error(
                    f"Error inspecting model '{model_name}': {str(e)}",
                    exc_info=True,
                )

            raise SQLNotifyConfigurationError(
                f"Failed to inspect model '{model_name}'. Ensure it's a valid SQLAlchemy or SQLModel model"
            ) from e

        # An explicit label replaces the schema/table/operation-based name.
        self.channel_name = f"{PACKAGE_NAME}_{self.schema_name}_{self.table_name}_{self.operation.value}"
        if self.channel_label:
            self.channel_name = f"{PACKAGE_NAME}_{replace_spaces_with_underscores(self.channel_label)}"

        # Function and trigger names use a fixed-length hash of the channel
        # name, so their length does not grow with the channel name.
        hashed_channel_name = hash_identifier(self.channel_name)

        self.function_name = f"notify_{hashed_channel_name}"
        self.trigger_name = f"trigger_{hashed_channel_name}"

        validate_identifier_size([self.channel_name, self.function_name, self.trigger_name])

        # NOTE: this attribute exists only when the overflow table is enabled.
        if self.use_overflow_table:
            self.overflow_table_name = f"{PACKAGE_NAME}_overflow"

    @staticmethod
    def _schema_from_table_args(model: type) -> str:
        """Return the schema named in ``model.__table_args__``, defaulting to ``"public"``."""
        table_args = getattr(model, "__table_args__", None)

        # ``__table_args__`` may be a tuple whose last element is the keyword
        # dict; calling ``.get`` on the tuple itself raises AttributeError.
        if isinstance(table_args, tuple):
            table_args = table_args[-1] if table_args else None

        if isinstance(table_args, dict):
            return table_args.get("schema") or "public"

        return "public"