maleo-database 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maleo/database/__init__.py +0 -0
- maleo/database/config/__init__.py +105 -0
- maleo/database/config/additional.py +36 -0
- maleo/database/config/connection.py +544 -0
- maleo/database/config/identifier.py +12 -0
- maleo/database/config/pooling.py +255 -0
- maleo/database/enums.py +85 -0
- maleo/database/managers/__init__.py +0 -0
- maleo/database/managers/clients/__init__.py +0 -0
- maleo/database/managers/clients/elasticsearch.py +66 -0
- maleo/database/managers/clients/mongodb.py +53 -0
- maleo/database/managers/clients/redis.py +57 -0
- maleo/database/managers/engines/__init__.py +0 -0
- maleo/database/managers/engines/mysql.py +56 -0
- maleo/database/managers/engines/postgresql.py +58 -0
- maleo/database/managers/engines/sqlite.py +55 -0
- maleo/database/managers/engines/sqlserver.py +63 -0
- maleo/database/managers/session.py +123 -0
- maleo/database/orm/__init__.py +0 -0
- maleo/database/orm/base.py +7 -0
- maleo/database/orm/models/__init__.py +0 -0
- maleo/database/orm/models/mixins/__init__.py +0 -0
- maleo/database/orm/models/mixins/identifier.py +11 -0
- maleo/database/orm/models/mixins/status.py +12 -0
- maleo/database/orm/models/mixins/timestamp.py +65 -0
- maleo/database/orm/models/table.py +17 -0
- maleo/database/orm/queries.py +234 -0
- maleo_database-0.0.1.dist-info/METADATA +93 -0
- maleo_database-0.0.1.dist-info/RECORD +32 -0
- maleo_database-0.0.1.dist-info/WHEEL +5 -0
- maleo_database-0.0.1.dist-info/licenses/LICENSE +57 -0
- maleo_database-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,12 @@
|
|
1
|
+
from pydantic import BaseModel, Field
|
2
|
+
from maleo.enums.environment import Environment
|
3
|
+
from maleo.types.base.dict import OptionalStringToStringDict
|
4
|
+
from maleo.types.base.string import OptionalString
|
5
|
+
|
6
|
+
|
7
|
+
class DatabaseIdentifierConfig(BaseModel):
    """Identifying metadata for one configured database.

    ``environment`` and ``name`` are required; the remaining fields have
    defaults.
    """

    # On/off switch for this database entry; enabled unless stated otherwise.
    enabled: bool = Field(
        default=True,
        description="Whether the database is enabled",
    )
    # Required: the environment this database belongs to.
    environment: Environment = Field(
        ...,
        description="Database's environment",
    )
    # Required: the database's name.
    name: str = Field(
        ...,
        description="Database's name",
    )
    # Optional free-text description.
    description: OptionalString = Field(
        default=None,
        description="Database's description",
    )
    # Optional string-to-string tags.
    tags: OptionalStringToStringDict = Field(
        default=None,
        description="Database's tags",
    )
|
@@ -0,0 +1,255 @@
|
|
1
|
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
2
|
+
from typing import Self, TypeVar
|
3
|
+
from maleo.types.base.dict import OptionalStringToStringDict
|
4
|
+
from maleo.types.base.integer import ListOfIntegers
|
5
|
+
from maleo.types.base.string import OptionalString
|
6
|
+
from maleo.utils.formatters.case import to_camel
|
7
|
+
from ..enums import PoolingStrategy
|
8
|
+
|
9
|
+
|
10
|
+
class BasePoolingConfig(BaseModel):
    """Common parent of every backend-specific pooling configuration.

    Declares no fields itself; it exists so the concrete pooling models in
    this module share one base type.
    """


# Type variable constrained to BasePoolingConfig subclasses, for code that is
# generic over "some concrete pooling configuration".
PoolingConfigT = TypeVar("PoolingConfigT", bound=BasePoolingConfig)
|
15
|
+
|
16
|
+
|
17
|
+
class PostgreSQLPoolingConfig(BasePoolingConfig):
    """Pooling options for PostgreSQL connections.

    Each field is range-checked through its ``Field`` constraints. On top of
    that, the model as a whole is rejected when ``max_overflow`` is more than
    five times ``pool_size``.
    """

    pool_size: int = Field(
        default=10,
        ge=1,
        le=1000,
        description="Number of connections in the pool",
    )
    max_overflow: int = Field(
        default=20,
        ge=0,
        le=500,
        description="Maximum number of overflow connections",
    )
    pool_timeout: float = Field(
        default=30.0,
        ge=1.0,
        le=300.0,
        description="Timeout in seconds for getting connection",
    )
    pool_recycle: int = Field(
        default=3600,
        ge=60,
        le=86400,
        description="Connection recycle time in seconds",
    )
    pool_pre_ping: bool = Field(
        default=True,
        description="Validate connections before use",
    )
    # strategy and prepared_statement_cache_size live here because they are
    # considered pooling concerns.
    strategy: PoolingStrategy = Field(
        default=PoolingStrategy.DYNAMIC,
        description="Pooling strategy",
    )
    prepared_statement_cache_size: int = Field(
        default=100,
        ge=0,
        le=10000,
        description="Prepared statement cache size",
    )
    pool_reset_on_return: bool = Field(
        default=True,
        description="Reset connection state on return to pool",
    )

    @model_validator(mode="after")
    def validate_overflow(self) -> Self:
        """Ensure ``max_overflow`` is at most five times ``pool_size``.

        Raises:
            ValueError: If the overflow allowance exceeds that ceiling.
        """
        overflow_ceiling = self.pool_size * 5
        if self.max_overflow <= overflow_ceiling:
            return self
        raise ValueError("max_overflow should not exceed 5x pool_size")
|
54
|
+
|
55
|
+
|
56
|
+
class MySQLPoolingConfig(BasePoolingConfig):
    """Pooling options for MySQL connections.

    Besides the pool sizing and timing fields, this model also carries
    ``autocommit`` and ``connect_timeout``.
    """

    pool_size: int = Field(
        default=8,
        ge=1,
        le=500,
        description="Number of connections in the pool",
    )
    max_overflow: int = Field(
        default=15,
        ge=0,
        le=200,
        description="Maximum number of overflow connections",
    )
    pool_timeout: float = Field(
        default=20.0,
        ge=1.0,
        le=300.0,
        description="Timeout in seconds for getting connection",
    )
    pool_recycle: int = Field(
        default=7200,
        ge=60,
        le=86400,
        description="Connection recycle time in seconds",
    )
    pool_pre_ping: bool = Field(
        default=True,
        description="Validate connections before use",
    )
    strategy: PoolingStrategy = Field(
        default=PoolingStrategy.FIXED,
        description="Pooling strategy",
    )
    # autocommit is kept with the pooling options because it changes how
    # pooled connections behave.
    autocommit: bool = Field(
        default=False,
        description="Enable autocommit mode",
    )
    # connect_timeout is kept here because it governs how pool connections
    # are established.
    connect_timeout: float = Field(
        default=10.0,
        ge=1.0,
        le=60.0,
        description="Connection timeout in seconds",
    )
|
86
|
+
|
87
|
+
|
88
|
+
class SQLitePoolingConfig(BasePoolingConfig):
    """Pooling options for SQLite connections.

    The size limits are much tighter than for the server databases
    (``pool_size`` is capped at 10 and defaults to a single connection).
    """

    pool_size: int = Field(
        default=1,
        ge=1,
        le=10,
        description="Number of connections (limited for SQLite)",
    )
    max_overflow: int = Field(
        default=5,
        ge=0,
        le=20,
        description="Maximum overflow connections",
    )
    pool_timeout: float = Field(
        default=30.0,
        ge=1.0,
        le=300.0,
        description="Timeout in seconds",
    )

    # Options that only make sense for SQLite.
    wal_mode: bool = Field(
        default=True,
        description="Enable WAL mode for better concurrency",
    )
    busy_timeout: int = Field(
        default=30000,
        ge=1000,
        le=300000,
        description="Busy timeout in milliseconds",
    )
|
107
|
+
|
108
|
+
|
109
|
+
class SQLServerPoolingConfig(BasePoolingConfig):
    """Pooling options for SQL Server connections.

    In addition to per-field range checks, the model is rejected when
    ``max_overflow`` is more than three times ``pool_size``.
    """

    pool_size: int = Field(
        default=10,
        ge=1,
        le=500,
        description="Number of connections in the pool",
    )
    max_overflow: int = Field(
        default=20,
        ge=0,
        le=200,
        description="Maximum number of overflow connections",
    )
    pool_timeout: float = Field(
        default=30.0,
        ge=1.0,
        le=300.0,
        description="Timeout in seconds for getting connection",
    )
    pool_recycle: int = Field(
        default=3600,
        ge=60,
        le=86400,
        description="Connection recycle time in seconds",
    )
    pool_pre_ping: bool = Field(
        default=True,
        description="Validate connections before use",
    )
    strategy: PoolingStrategy = Field(
        default=PoolingStrategy.DYNAMIC,
        description="Pooling strategy",
    )

    # Settings that exist only for SQL Server.
    connection_timeout: int = Field(
        default=30,
        ge=1,
        le=300,
        description="Connection timeout in seconds",
    )
    command_timeout: int = Field(
        default=30,
        ge=1,
        le=3600,
        description="Command timeout in seconds",
    )
    packet_size: int = Field(
        default=4096,
        ge=512,
        le=32767,
        description="Network packet size",
    )
    trust_server_certificate: bool = Field(
        default=False,
        description="Trust server certificate",
    )
    # encrypt is grouped with pooling because it affects how pooled
    # connections behave.
    encrypt: bool = Field(
        default=True,
        description="Encrypt connection",
    )

    @model_validator(mode="after")
    def validate_overflow(self) -> Self:
        """Ensure ``max_overflow`` is at most three times ``pool_size``.

        Raises:
            ValueError: If the overflow allowance exceeds that ceiling.
        """
        overflow_ceiling = self.pool_size * 3
        if self.max_overflow <= overflow_ceiling:
            return self
        raise ValueError("max_overflow should not exceed 3x pool_size")
|
154
|
+
|
155
|
+
|
156
|
+
class MongoDBPoolingConfig(BasePoolingConfig):
    """MongoDB-specific pooling configuration.

    Fields are declared in snake_case and given an alias by ``to_camel``.
    ``MongoDBClientManager`` dumps this model with ``by_alias=True`` and
    forwards the result as keyword arguments to the Mongo clients, so the
    aliases are what the driver receives.

    Input is accepted under either the alias or the Python field name
    (``populate_by_name=True``). Without that, a snake_case key -- the
    convention used by every other pooling config in this module -- would be
    treated as an unknown extra and silently ignored, leaving the default in
    place.
    """

    model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)

    max_pool_size: int = Field(
        default=100, ge=1, le=500, description="Maximum number of connections in pool"
    )
    min_pool_size: int = Field(
        default=0, ge=0, le=100, description="Minimum number of connections in pool"
    )
    max_idle_time_ms: int = Field(
        default=600000, ge=1000, le=3600000, description="Max idle time in milliseconds"
    )
    connect_timeout_ms: int = Field(
        default=20000,
        ge=1000,
        le=300000,
        description="Connection timeout in milliseconds",
    )
    server_selection_timeout_ms: int = Field(
        default=30000, ge=1000, le=300000, description="Server selection timeout"
    )
    max_connecting: int = Field(
        default=2,
        ge=1,
        le=10,
        description="Maximum number of concurrent connection attempts",
    )

    @model_validator(mode="after")
    def validate_pool_bounds(self) -> Self:
        """Reject a minimum pool size larger than the maximum pool size.

        The per-field ranges overlap (min up to 100, max down to 1), so the
        combination has to be checked on the whole model.

        Raises:
            ValueError: If ``min_pool_size`` is greater than ``max_pool_size``.
        """
        if self.min_pool_size > self.max_pool_size:
            raise ValueError("min_pool_size must not exceed max_pool_size")
        return self
|
185
|
+
|
186
|
+
|
187
|
+
class RedisPoolingConfig(BasePoolingConfig):
    """Pooling options for Redis connections.

    Covers pool capacity, retry/health-check behaviour, the two timeouts and
    response decoding.
    """

    max_connections: int = Field(
        default=50,
        ge=1,
        le=1000,
        description="Maximum number of connections in pool",
    )
    retry_on_timeout: bool = Field(
        default=True,
        description="Retry on connection timeout",
    )
    health_check_interval: int = Field(
        default=30,
        ge=5,
        le=300,
        description="Health check interval in seconds",
    )
    connection_timeout: float = Field(
        default=5.0,
        ge=1.0,
        le=60.0,
        description="Connection timeout in seconds",
    )
    socket_timeout: float = Field(
        default=5.0,
        ge=1.0,
        le=60.0,
        description="Socket timeout in seconds",
    )
    socket_keepalive: bool = Field(
        default=True,
        description="Enable TCP keepalive",
    )
    decode_responses: bool = Field(
        default=True,
        description="Decode responses to strings",
    )
|
209
|
+
|
210
|
+
|
211
|
+
class ElasticsearchPoolingConfig(BasePoolingConfig):
    """Pooling options for Elasticsearch connections.

    Besides per-field range checks, the model is rejected when
    ``connections_per_node`` is larger than ``maxsize``.
    """

    # --- Pool capacity ---
    maxsize: int = Field(
        default=25,
        ge=1,
        le=100,
        description="Maximum number of connections in pool",
    )
    connections_per_node: int = Field(
        default=10,
        ge=1,
        le=50,
        description="Connections per Elasticsearch node",
    )

    # --- Timeouts and retries ---
    timeout: float = Field(
        default=10.0,
        ge=1.0,
        le=300.0,
        description="Request timeout in seconds",
    )
    max_retries: int = Field(
        default=3,
        ge=0,
        le=10,
        description="Maximum number of retries",
    )
    retry_on_timeout: bool = Field(
        default=False,
        description="Retry on timeout",
    )
    # default_factory gives every instance its own list.
    retry_on_status: ListOfIntegers = Field(
        default_factory=lambda: [502, 503, 504],
        description="HTTP status codes to retry on",
    )

    # --- Connection behaviour (relocated from the connection config) ---
    http_compress: bool = Field(
        default=True,
        description="Enable HTTP compression",
    )
    verify_certs: bool = Field(
        default=True,
        description="Verify SSL certificates",
    )
    ca_certs: OptionalString = Field(
        default=None,
        description="Path to CA certificates",
    )

    # --- Advanced pool settings ---
    block: bool = Field(
        default=False,
        description="Block when pool is full",
    )
    headers: OptionalStringToStringDict = Field(
        default=None,
        description="Default headers for requests",
    )
    dead_timeout: float = Field(
        default=60.0,
        ge=5.0,
        le=600.0,
        description="Dead node timeout in seconds",
    )

    @model_validator(mode="after")
    def validate_overflow(self) -> Self:
        """Ensure ``connections_per_node`` does not exceed ``maxsize``.

        Raises:
            ValueError: If the per-node figure is larger than the pool size.
        """
        if self.connections_per_node <= self.maxsize:
            return self
        raise ValueError("connections_per_node must not exceed maxsize")
|
maleo/database/enums.py
ADDED
@@ -0,0 +1,85 @@
|
|
1
|
+
from enum import StrEnum
|
2
|
+
|
3
|
+
|
4
|
+
class Connection(StrEnum):
|
5
|
+
ASYNC = "async"
|
6
|
+
SYNC = "sync"
|
7
|
+
|
8
|
+
|
9
|
+
class Driver(StrEnum):
|
10
|
+
# SQL Databases - Most Popular
|
11
|
+
POSTGRESQL = "postgresql"
|
12
|
+
MYSQL = "mysql"
|
13
|
+
SQLITE = "sqlite"
|
14
|
+
|
15
|
+
# SQL Databases - Enterprise
|
16
|
+
# ORACLE = "oracle"
|
17
|
+
MSSQL = "mssql" # SQL Server
|
18
|
+
# MARIADB = "mariadb"
|
19
|
+
|
20
|
+
# NoSQL Document Stores
|
21
|
+
MONGODB = "mongodb"
|
22
|
+
# COUCHDB = "couchdb"
|
23
|
+
|
24
|
+
# NoSQL Key-Value
|
25
|
+
REDIS = "redis"
|
26
|
+
# DYNAMODB = "dynamodb" # AWS
|
27
|
+
|
28
|
+
# NoSQL Column Family
|
29
|
+
# CASSANDRA = "cassandra"
|
30
|
+
# HBASE = "hbase"
|
31
|
+
|
32
|
+
# NoSQL Graph
|
33
|
+
# NEO4J = "neo4j"
|
34
|
+
# ARANGODB = "arangodb"
|
35
|
+
|
36
|
+
# Time Series
|
37
|
+
# INFLUXDB = "influxdb"
|
38
|
+
# TIMESCALEDB = "timescaledb" # PostgreSQL extension
|
39
|
+
|
40
|
+
# In-Memory
|
41
|
+
# MEMCACHED = "memcached"
|
42
|
+
|
43
|
+
# Search Engines
|
44
|
+
ELASTICSEARCH = "elasticsearch"
|
45
|
+
# OPENSEARCH = "opensearch"
|
46
|
+
|
47
|
+
# Cloud Native
|
48
|
+
# FIRESTORE = "firestore" # Google
|
49
|
+
# COSMOSDB = "cosmosdb" # Azure
|
50
|
+
|
51
|
+
|
52
|
+
class PostgreSQLSSLMode(StrEnum):
|
53
|
+
DISABLE = "disable"
|
54
|
+
ALLOW = "allow"
|
55
|
+
PREFER = "prefer"
|
56
|
+
REQUIRE = "require"
|
57
|
+
VERIFY_CA = "verify-ca"
|
58
|
+
VERIFY_FULL = "verify-full"
|
59
|
+
|
60
|
+
|
61
|
+
class MySQLCharset(StrEnum):
|
62
|
+
UTF8 = "utf8"
|
63
|
+
UTF8MB4 = "utf8mb4"
|
64
|
+
LATIN1 = "latin1"
|
65
|
+
ASCII = "ascii"
|
66
|
+
|
67
|
+
|
68
|
+
class MongoDBReadPreference(StrEnum):
|
69
|
+
PRIMARY = "primary"
|
70
|
+
PRIMARY_PREFERRED = "primaryPreferred"
|
71
|
+
SECONDARY = "secondary"
|
72
|
+
SECONDARY_PREFERRED = "secondaryPreferred"
|
73
|
+
NEAREST = "nearest"
|
74
|
+
|
75
|
+
|
76
|
+
class ElasticsearchScheme(StrEnum):
|
77
|
+
HTTP = "http"
|
78
|
+
HTTPS = "https"
|
79
|
+
|
80
|
+
|
81
|
+
class PoolingStrategy(StrEnum):
|
82
|
+
FIXED = "fixed"
|
83
|
+
DYNAMIC = "dynamic"
|
84
|
+
OVERFLOW = "overflow"
|
85
|
+
QUEUE = "queue"
|
File without changes
|
File without changes
|
@@ -0,0 +1,66 @@
|
|
1
|
+
from elasticsearch import AsyncElasticsearch, Elasticsearch
|
2
|
+
from typing import Literal, Union, overload
|
3
|
+
from ...config import ElasticsearchDatabaseConfig
|
4
|
+
from ...enums import Connection
|
5
|
+
|
6
|
+
|
7
|
+
class ElasticsearchClientManager:
    """Build and hold one async and one sync Elasticsearch client.

    Both clients are created eagerly in ``__init__`` from the same
    configuration and are released together by ``dispose``.
    """

    def __init__(self, config: ElasticsearchDatabaseConfig) -> None:
        """Store ``config`` and create both clients immediately."""
        self.config = config
        self._async_client: AsyncElasticsearch = self._init(Connection.ASYNC)
        self._sync_client: Elasticsearch = self._init(Connection.SYNC)

    @overload
    def _init(self, connection: Literal[Connection.ASYNC]) -> AsyncElasticsearch: ...
    @overload
    def _init(self, connection: Literal[Connection.SYNC]) -> Elasticsearch: ...
    def _init(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncElasticsearch, Elasticsearch]:
        """Create the client for ``connection``, store it on ``self`` and return it.

        Any value other than ``Connection.ASYNC`` produces the sync client.
        """
        hosts = [
            {"host": self.config.connection.host, "port": self.config.connection.port}
        ]

        client_kwargs = {}

        # Credentials are only sent when both parts are present.
        if self.config.connection.username and self.config.connection.password:
            client_kwargs["http_auth"] = (
                self.config.connection.username,
                self.config.connection.password,
            )

        # Forward the pooling model as client keyword arguments, minus the
        # fields that are deliberately not passed to the client.
        # NOTE(review): `http_auth`, `maxsize` and `timeout` look like
        # 7.x-era client options -- confirm against the pinned
        # elasticsearch package version.
        pooling_kwargs = self.config.pooling.model_dump(
            exclude={
                "connections_per_node",
                "block",
                "headers",
                "dead_timeout",
            },
            exclude_none=True,
        )
        client_kwargs.update(pooling_kwargs)

        if connection is Connection.ASYNC:
            self._async_client = AsyncElasticsearch(hosts, **client_kwargs)
            return self._async_client
        else:
            self._sync_client = Elasticsearch(hosts, **client_kwargs)
            return self._sync_client

    @overload
    def get(self, connection: Literal[Connection.ASYNC]) -> AsyncElasticsearch: ...
    @overload
    def get(self, connection: Literal[Connection.SYNC]) -> Elasticsearch: ...
    def get(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncElasticsearch, Elasticsearch]:
        """Return the stored client for ``connection``.

        If the stored client is falsy a new one is created via ``_init``.
        Any value other than ``Connection.ASYNC`` selects the sync client.
        """
        if connection is Connection.ASYNC:
            return self._async_client or self._init(Connection.ASYNC)
        else:
            return self._sync_client or self._init(Connection.SYNC)

    async def dispose(self):
        """Close both clients.

        The sync client is closed in a ``finally`` clause so it is released
        even when closing the async client raises; that exception still
        propagates to the caller.
        """
        try:
            await self._async_client.close()
        finally:
            self._sync_client.close()
|
@@ -0,0 +1,53 @@
|
|
1
|
+
from motor.motor_asyncio import AsyncIOMotorClient
|
2
|
+
from pymongo import MongoClient
|
3
|
+
from typing import Literal, Union, overload
|
4
|
+
from ...config import MongoDBDatabaseConfig
|
5
|
+
from ...enums import Connection
|
6
|
+
|
7
|
+
|
8
|
+
class MongoDBClientManager:
    """Build and hold one async (Motor) and one sync (PyMongo) client.

    Both clients are created eagerly in ``__init__`` from the same
    configuration and are released together by ``dispose``.
    """

    def __init__(self, config: MongoDBDatabaseConfig) -> None:
        """Store ``config`` and create both clients immediately."""
        self.config = config
        self._async_client: AsyncIOMotorClient = self._init(Connection.ASYNC)
        self._sync_client: MongoClient = self._init(Connection.SYNC)

    @overload
    def _init(self, connection: Literal[Connection.ASYNC]) -> AsyncIOMotorClient: ...
    @overload
    def _init(self, connection: Literal[Connection.SYNC]) -> MongoClient: ...
    def _init(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncIOMotorClient, MongoClient]:
        """Create the client for ``connection``, store it on ``self`` and return it.

        Any value other than ``Connection.ASYNC`` produces the sync client.
        """
        url = self.config.connection.make_url(connection)

        # Dump by alias so the driver receives the aliased option names
        # rather than the Python field names.
        pooling_kwargs = self.config.pooling.model_dump(
            by_alias=True, exclude_none=True
        )

        if connection is Connection.ASYNC:
            self._async_client = AsyncIOMotorClient(url, **pooling_kwargs)
            return self._async_client
        else:
            self._sync_client = MongoClient(url, **pooling_kwargs)
            return self._sync_client

    @overload
    def get(self, connection: Literal[Connection.ASYNC]) -> AsyncIOMotorClient: ...
    @overload
    def get(self, connection: Literal[Connection.SYNC]) -> MongoClient: ...
    def get(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncIOMotorClient, MongoClient]:
        """Return the stored client for ``connection``.

        Matching is by identity against the ``Connection`` members; a value
        that is neither member falls through and yields ``None``.
        """
        if connection is Connection.ASYNC:
            return self._async_client
        elif connection is Connection.SYNC:
            return self._sync_client

    def get_database(self, connection: Connection = Connection.ASYNC):
        """Get the specific database object.

        Indexes the selected client with the database name from the
        connection configuration.
        """
        client = self.get(connection)
        return client[self.config.connection.database]

    async def dispose(self):
        """Close both clients.

        The sync client is closed in a ``finally`` clause so it is released
        even when closing the async client raises; that exception still
        propagates to the caller.
        """
        try:
            self._async_client.close()
        finally:
            self._sync_client.close()
|
@@ -0,0 +1,57 @@
|
|
1
|
+
from redis.asyncio import Redis as AsyncRedis
|
2
|
+
from redis import Redis as SyncRedis
|
3
|
+
from typing import Literal, Union, overload
|
4
|
+
from ...config import RedisDatabaseConfig
|
5
|
+
from ...enums import Connection
|
6
|
+
|
7
|
+
|
8
|
+
class RedisClientManager:
    """Build and hold one async and one sync Redis client.

    Both clients are created eagerly in ``__init__`` from the same
    configuration and are released together by ``dispose``.
    """

    def __init__(self, config: RedisDatabaseConfig) -> None:
        """Store ``config`` and create both clients immediately."""
        self.config = config
        self._async_client: AsyncRedis = self._init(Connection.ASYNC)
        self._sync_client: SyncRedis = self._init(Connection.SYNC)

    @overload
    def _init(self, connection: Literal[Connection.ASYNC]) -> AsyncRedis: ...
    @overload
    def _init(self, connection: Literal[Connection.SYNC]) -> SyncRedis: ...
    def _init(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncRedis, SyncRedis]:
        """Create the client for ``connection``, store it on ``self`` and return it.

        Any value other than ``Connection.ASYNC`` produces the sync client.
        """
        url = self.config.connection.make_url(connection)

        pooling_config = self.config.pooling.model_dump(exclude_none=True)

        # Pooling field name -> keyword understood by the Redis client.
        # `connection_timeout` has to be renamed: redis-py names the connect
        # timeout `socket_connect_timeout`, and an unknown keyword passed to
        # `from_url` ends up on the connection constructor, which rejects it.
        # NOTE(review): health_check_interval is deliberately left out, as in
        # the original code, which states it does not apply to from_url --
        # confirm against the redis-py version in use.
        option_names = {
            "max_connections": "max_connections",
            "retry_on_timeout": "retry_on_timeout",
            "connection_timeout": "socket_connect_timeout",
            "socket_timeout": "socket_timeout",
            "socket_keepalive": "socket_keepalive",
            "decode_responses": "decode_responses",
        }
        redis_kwargs = {
            client_option: pooling_config[field_name]
            for field_name, client_option in option_names.items()
            if pooling_config.get(field_name) is not None
        }

        if connection is Connection.ASYNC:
            self._async_client = AsyncRedis.from_url(url, **redis_kwargs)
            return self._async_client
        else:
            self._sync_client = SyncRedis.from_url(url, **redis_kwargs)
            return self._sync_client

    @overload
    def get(self, connection: Literal[Connection.ASYNC]) -> AsyncRedis: ...
    @overload
    def get(self, connection: Literal[Connection.SYNC]) -> SyncRedis: ...
    def get(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncRedis, SyncRedis]:
        """Return the stored client for ``connection``.

        Matching is by identity against the ``Connection`` members; a value
        that is neither member falls through and yields ``None``.
        """
        if connection is Connection.ASYNC:
            return self._async_client
        elif connection is Connection.SYNC:
            return self._sync_client

    async def dispose(self):
        """Close both clients.

        The sync client is closed in a ``finally`` clause so it is released
        even when closing the async client raises; that exception still
        propagates to the caller.
        """
        try:
            await self._async_client.close()
        finally:
            self._sync_client.close()
|
File without changes
|
@@ -0,0 +1,56 @@
|
|
1
|
+
from sqlalchemy.engine import create_engine, Engine
|
2
|
+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
3
|
+
from typing import Literal, Tuple, Union, overload
|
4
|
+
from ...config import MySQLDatabaseConfig
|
5
|
+
from ...enums import Connection
|
6
|
+
|
7
|
+
|
8
|
+
class MySQLEngineManager:
    """Build and hold one async and one sync SQLAlchemy engine for MySQL.

    Both engines are created eagerly in ``__init__`` from the same
    configuration and are released together by ``dispose``.
    """

    # Pooling fields that are driver connection options rather than engine
    # options; they are routed through ``connect_args``.
    _CONNECT_ARG_FIELDS = ("autocommit", "connect_timeout")

    def __init__(self, config: MySQLDatabaseConfig) -> None:
        """Store ``config`` and create both engines immediately."""
        self.config = config
        self._async_engine: AsyncEngine = self._init(Connection.ASYNC)
        self._sync_engine: Engine = self._init(Connection.SYNC)

    @overload
    def _init(self, connection: Literal[Connection.ASYNC]) -> AsyncEngine: ...
    @overload
    def _init(self, connection: Literal[Connection.SYNC]) -> Engine: ...
    def _init(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncEngine, Engine]:
        """Create the engine for ``connection``, store it on ``self`` and return it.

        A value that is neither ``Connection`` member creates nothing and
        yields ``None``.
        """
        url = self.config.connection.make_url(connection)

        pooling_kwargs = self.config.pooling.model_dump(
            exclude={"strategy"},
            exclude_none=True,
        )

        # create_engine rejects keyword arguments it does not recognise, and
        # `autocommit` / `connect_timeout` are not engine options. They are
        # options of the DBAPI connect call, so hand them over via
        # `connect_args` instead of as top-level keywords.
        # NOTE(review): assumes the driver chosen by make_url accepts both
        # keywords -- confirm for the configured sync and async drivers.
        connect_args = {
            option: pooling_kwargs.pop(option)
            for option in self._CONNECT_ARG_FIELDS
            if option in pooling_kwargs
        }

        engine_kwargs = {"echo": self.config.connection.echo, **pooling_kwargs}
        if connect_args:
            engine_kwargs["connect_args"] = connect_args

        if connection is Connection.ASYNC:
            self._async_engine = create_async_engine(url, **engine_kwargs)
            return self._async_engine
        elif connection is Connection.SYNC:
            self._sync_engine = create_engine(url, **engine_kwargs)
            return self._sync_engine

    @overload
    def get(self, connection: Literal[Connection.ASYNC]) -> AsyncEngine: ...
    @overload
    def get(self, connection: Literal[Connection.SYNC]) -> Engine: ...
    def get(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncEngine, Engine]:
        """Return the stored engine for ``connection``.

        Matching is by identity against the ``Connection`` members; a value
        that is neither member falls through and yields ``None``.
        """
        if connection is Connection.ASYNC:
            return self._async_engine
        elif connection is Connection.SYNC:
            return self._sync_engine

    def get_all(self) -> Tuple[AsyncEngine, Engine]:
        """Return ``(async_engine, sync_engine)``."""
        return (self._async_engine, self._sync_engine)

    async def dispose(self):
        """Dispose both engines.

        The sync engine is disposed in a ``finally`` clause so it is released
        even when disposing the async engine raises; that exception still
        propagates to the caller.
        """
        try:
            await self._async_engine.dispose()
        finally:
            self._sync_engine.dispose()
|
@@ -0,0 +1,58 @@
|
|
1
|
+
from sqlalchemy.engine import create_engine, Engine
|
2
|
+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
3
|
+
from typing import Literal, Tuple, Union, overload
|
4
|
+
from ...config import PostgreSQLDatabaseConfig
|
5
|
+
from ...enums import Connection
|
6
|
+
|
7
|
+
|
8
|
+
class PostgreSQLEngineManager:
    """Build and hold one async and one sync SQLAlchemy engine for PostgreSQL.

    Both engines are created eagerly in ``__init__`` from the same
    configuration and are released together by ``dispose``.
    """

    def __init__(self, config: PostgreSQLDatabaseConfig) -> None:
        """Store ``config`` and create both engines immediately."""
        self.config = config
        self._async_engine: AsyncEngine = self._init(Connection.ASYNC)
        self._sync_engine: Engine = self._init(Connection.SYNC)

    @overload
    def _init(self, connection: Literal[Connection.ASYNC]) -> AsyncEngine: ...
    @overload
    def _init(self, connection: Literal[Connection.SYNC]) -> Engine: ...
    def _init(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncEngine, Engine]:
        """Create the engine for ``connection``, store it on ``self`` and return it.

        A value that is neither ``Connection`` member creates nothing and
        yields ``None``.
        """
        url = self.config.connection.make_url(connection)

        # Everything left after these exclusions is forwarded to the engine
        # factory as keyword arguments.
        pooling_kwargs = self.config.pooling.model_dump(
            exclude={
                "strategy",
                "prepared_statement_cache_size",
                "pool_reset_on_return",
            },
            exclude_none=True,
        )

        engine_kwargs = {"echo": self.config.connection.echo, **pooling_kwargs}

        if connection is Connection.ASYNC:
            self._async_engine = create_async_engine(url, **engine_kwargs)
            return self._async_engine
        elif connection is Connection.SYNC:
            self._sync_engine = create_engine(url, **engine_kwargs)
            return self._sync_engine

    @overload
    def get(self, connection: Literal[Connection.ASYNC]) -> AsyncEngine: ...
    @overload
    def get(self, connection: Literal[Connection.SYNC]) -> Engine: ...
    def get(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncEngine, Engine]:
        """Return the stored engine for ``connection``.

        Matching is by identity against the ``Connection`` members; a value
        that is neither member falls through and yields ``None``.
        """
        if connection is Connection.ASYNC:
            return self._async_engine
        elif connection is Connection.SYNC:
            return self._sync_engine

    def get_all(self) -> Tuple[AsyncEngine, Engine]:
        """Return ``(async_engine, sync_engine)``."""
        return (self._async_engine, self._sync_engine)

    async def dispose(self):
        """Dispose both engines.

        The sync engine is disposed in a ``finally`` clause so it is released
        even when disposing the async engine raises; that exception still
        propagates to the caller.
        """
        try:
            await self._async_engine.dispose()
        finally:
            self._sync_engine.dispose()
|