aws-advanced-python-wrapper 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. CONTRIBUTING.md +63 -0
  2. aws_advanced_python_wrapper/__init__.py +28 -0
  3. aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py +228 -0
  4. aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py +240 -0
  5. aws_advanced_python_wrapper/aws_secrets_manager_plugin.py +218 -0
  6. aws_advanced_python_wrapper/connect_time_plugin.py +69 -0
  7. aws_advanced_python_wrapper/connection_provider.py +232 -0
  8. aws_advanced_python_wrapper/database_dialect.py +708 -0
  9. aws_advanced_python_wrapper/default_plugin.py +144 -0
  10. aws_advanced_python_wrapper/developer_plugin.py +163 -0
  11. aws_advanced_python_wrapper/driver_configuration_profiles.py +44 -0
  12. aws_advanced_python_wrapper/driver_dialect.py +165 -0
  13. aws_advanced_python_wrapper/driver_dialect_codes.py +19 -0
  14. aws_advanced_python_wrapper/driver_dialect_manager.py +121 -0
  15. aws_advanced_python_wrapper/driver_info.py +18 -0
  16. aws_advanced_python_wrapper/errors.py +47 -0
  17. aws_advanced_python_wrapper/exception_handling.py +73 -0
  18. aws_advanced_python_wrapper/execute_time_plugin.py +58 -0
  19. aws_advanced_python_wrapper/failover_plugin.py +517 -0
  20. aws_advanced_python_wrapper/failover_result.py +42 -0
  21. aws_advanced_python_wrapper/fastest_response_strategy_plugin.py +345 -0
  22. aws_advanced_python_wrapper/federated_plugin.py +382 -0
  23. aws_advanced_python_wrapper/host_availability.py +86 -0
  24. aws_advanced_python_wrapper/host_list_provider.py +645 -0
  25. aws_advanced_python_wrapper/host_monitoring_plugin.py +728 -0
  26. aws_advanced_python_wrapper/host_selector.py +190 -0
  27. aws_advanced_python_wrapper/hostinfo.py +138 -0
  28. aws_advanced_python_wrapper/iam_plugin.py +195 -0
  29. aws_advanced_python_wrapper/mysql_driver_dialect.py +175 -0
  30. aws_advanced_python_wrapper/pep249.py +196 -0
  31. aws_advanced_python_wrapper/pg_driver_dialect.py +176 -0
  32. aws_advanced_python_wrapper/plugin.py +148 -0
  33. aws_advanced_python_wrapper/plugin_service.py +949 -0
  34. aws_advanced_python_wrapper/read_write_splitting_plugin.py +363 -0
  35. aws_advanced_python_wrapper/reader_failover_handler.py +252 -0
  36. aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +315 -0
  37. aws_advanced_python_wrapper/sql_alchemy_connection_provider.py +196 -0
  38. aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py +127 -0
  39. aws_advanced_python_wrapper/stale_dns_plugin.py +209 -0
  40. aws_advanced_python_wrapper/states/__init__.py +13 -0
  41. aws_advanced_python_wrapper/states/session_state.py +94 -0
  42. aws_advanced_python_wrapper/states/session_state_service.py +221 -0
  43. aws_advanced_python_wrapper/utils/__init__.py +13 -0
  44. aws_advanced_python_wrapper/utils/atomic.py +51 -0
  45. aws_advanced_python_wrapper/utils/cache_map.py +99 -0
  46. aws_advanced_python_wrapper/utils/concurrent.py +100 -0
  47. aws_advanced_python_wrapper/utils/decorators.py +70 -0
  48. aws_advanced_python_wrapper/utils/failover_mode.py +39 -0
  49. aws_advanced_python_wrapper/utils/iamutils.py +75 -0
  50. aws_advanced_python_wrapper/utils/log.py +75 -0
  51. aws_advanced_python_wrapper/utils/messages.py +36 -0
  52. aws_advanced_python_wrapper/utils/mysql_exception_handler.py +73 -0
  53. aws_advanced_python_wrapper/utils/notifications.py +37 -0
  54. aws_advanced_python_wrapper/utils/pg_exception_handler.py +115 -0
  55. aws_advanced_python_wrapper/utils/properties.py +492 -0
  56. aws_advanced_python_wrapper/utils/rds_url_type.py +36 -0
  57. aws_advanced_python_wrapper/utils/rdsutils.py +226 -0
  58. aws_advanced_python_wrapper/utils/sliding_expiration_cache.py +146 -0
  59. aws_advanced_python_wrapper/utils/telemetry/default_telemetry_factory.py +82 -0
  60. aws_advanced_python_wrapper/utils/telemetry/null_telemetry.py +55 -0
  61. aws_advanced_python_wrapper/utils/telemetry/open_telemetry.py +189 -0
  62. aws_advanced_python_wrapper/utils/telemetry/telemetry.py +85 -0
  63. aws_advanced_python_wrapper/utils/telemetry/xray_telemetry.py +126 -0
  64. aws_advanced_python_wrapper/utils/utils.py +89 -0
  65. aws_advanced_python_wrapper/wrapper.py +322 -0
  66. aws_advanced_python_wrapper/writer_failover_handler.py +347 -0
  67. aws_advanced_python_wrapper-1.0.0.dist-info/LICENSE +201 -0
  68. aws_advanced_python_wrapper-1.0.0.dist-info/METADATA +261 -0
  69. aws_advanced_python_wrapper-1.0.0.dist-info/RECORD +70 -0
  70. aws_advanced_python_wrapper-1.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,708 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License").
4
+ # You may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import (TYPE_CHECKING, Callable, ClassVar, Dict, Optional,
18
+ Protocol, Tuple, runtime_checkable)
19
+
20
+ from aws_advanced_python_wrapper.driver_info import DriverInfo
21
+
22
+ if TYPE_CHECKING:
23
+ from aws_advanced_python_wrapper.pep249 import Connection
24
+ from .driver_dialect import DriverDialect
25
+ from .exception_handling import ExceptionHandler
26
+
27
+ from abc import abstractmethod
28
+ from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
29
+ from contextlib import closing
30
+ from enum import Enum, auto
31
+
32
+ from aws_advanced_python_wrapper.errors import (AwsWrapperError,
33
+ QueryTimeoutError)
34
+ from aws_advanced_python_wrapper.host_list_provider import (
35
+ ConnectionStringHostListProvider, MultiAzHostListProvider,
36
+ RdsHostListProvider)
37
+ from aws_advanced_python_wrapper.hostinfo import HostInfo
38
+ from aws_advanced_python_wrapper.utils.decorators import \
39
+ preserve_transaction_status_with_timeout
40
+ from aws_advanced_python_wrapper.utils.log import Logger
41
+ from aws_advanced_python_wrapper.utils.properties import (Properties,
42
+ PropertiesUtils,
43
+ WrapperProperties)
44
+ from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
45
+ from .driver_dialect_codes import DriverDialectCodes
46
+ from .utils.cache_map import CacheMap
47
+ from .utils.messages import Messages
48
+ from .utils.utils import Utils
49
+
50
# Module-level logger shared by every dialect class defined in this module.
logger: Logger = Logger(__name__)
51
+
52
+
53
class DialectCode(Enum):
    """Identifiers for every database dialect known to the wrapper.

    The string values are the accepted values of the ``DIALECT`` wrapper property
    (parsed with :meth:`from_string`).
    """

    # https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/multi-az-db-clusters-concepts.html
    MULTI_AZ_MYSQL = "multi-az-mysql"
    AURORA_MYSQL = "aurora-mysql"
    RDS_MYSQL = "rds-mysql"
    MYSQL = "mysql"

    MULTI_AZ_PG = "multi-az-pg"
    AURORA_PG = "aurora-pg"
    RDS_PG = "rds-pg"
    PG = "pg"

    CUSTOM = "custom"
    UNKNOWN = "unknown"

    @staticmethod
    def from_string(value: str) -> DialectCode:
        """Return the member whose value equals *value*.

        :param value: a dialect string such as ``"aurora-pg"``.
        :return: the matching :class:`DialectCode`.
        :raises AwsWrapperError: if *value* does not name a known dialect.
        """
        try:
            return DialectCode(value)
        except ValueError as e:
            # Chain explicitly so the original lookup failure is kept as the cause.
            raise AwsWrapperError(Messages.get_formatted("DialectCode.InvalidStringValue", value)) from e
74
+
75
+
76
class TargetDriverType(Enum):
    """Family of the underlying DB-API driver the wrapper is wrapping.

    ``CUSTOM`` covers any driver that is neither the PostgreSQL nor the MySQL one.
    """

    MYSQL = auto()     # mysql-connector-python
    POSTGRES = auto()  # psycopg
    CUSTOM = auto()    # anything else
80
+
81
+
82
@runtime_checkable
class TopologyAwareDatabaseDialect(Protocol):
    """Structural interface for dialects that can describe a cluster topology.

    Implementers supply the three SQL strings as class attributes; the read-only
    properties below simply expose them.
    """

    _TOPOLOGY_QUERY: str  # SQL listing the hosts of the cluster
    _HOST_ID_QUERY: str  # SQL identifying the server behind the current connection
    _IS_READER_QUERY: str  # SQL reporting whether the current server is read-only

    @property
    def topology_query(self) -> str:
        """SQL used to fetch the cluster topology."""
        return self._TOPOLOGY_QUERY

    @property
    def host_id_query(self) -> str:
        """SQL used to identify the connected host."""
        return self._HOST_ID_QUERY

    @property
    def is_reader_query(self) -> str:
        """SQL used to decide whether the connected host is a reader."""
        return self._IS_READER_QUERY
99
+
100
+
101
class DatabaseDialect(Protocol):
    """
    Database dialects help the AWS Advanced Python Driver determine what kind of underlying
    database is being used, and configure details unique to specific databases.
    """

    @property
    @abstractmethod
    def default_port(self) -> int:
        """Default port of this database engine."""

    @property
    @abstractmethod
    def host_alias_query(self) -> str:
        """SQL returning a ``host:port`` alias for the connected server."""

    @property
    @abstractmethod
    def server_version_query(self) -> str:
        """SQL whose result describes the server version/distribution."""

    @property
    @abstractmethod
    def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
        """Dialect codes that may refine this dialect once a connection is available, or None."""

    @property
    @abstractmethod
    def exception_handler(self) -> Optional[ExceptionHandler]:
        """Handler for this engine's driver exceptions, or None if there is none."""

    @abstractmethod
    def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
        """Return True if *conn* appears to be connected to a database of this dialect."""

    @abstractmethod
    def get_host_list_provider_supplier(self) -> Callable:
        """Return a callable ``(provider_service, props)`` that builds a host list provider."""

    @abstractmethod
    def prepare_conn_props(self, props: Properties):
        """Hook letting the dialect modify the connection properties *props* in place."""
143
+
144
+
145
class DatabaseDialectProvider(Protocol):
    """Interface for components that choose a :class:`DatabaseDialect` for a connection."""

    def get_dialect(self, driver_dialect: str, props: Properties) -> Optional[DatabaseDialect]:
        """
        Return the dialect identified from the AwsWrapperProperties.DIALECT property (if set)
        or, failing that, from the target driver.
        """

    def query_for_dialect(self, url: str, host_info: Optional[HostInfo], conn: Connection,
                          driver_dialect: DriverDialect) -> Optional[DatabaseDialect]:
        """Return the dialect identified by querying the database to determine its engine type."""
157
+
158
+
159
class MysqlDatabaseDialect(DatabaseDialect):
    """Dialect for a generic MySQL server.

    A connection is recognised as MySQL when any value returned by
    ``server_version_query`` contains "mysql" (case-insensitive).
    """

    _DIALECT_UPDATE_CANDIDATES: Tuple[DialectCode, ...] = (
        DialectCode.AURORA_MYSQL, DialectCode.MULTI_AZ_MYSQL, DialectCode.RDS_MYSQL)
    # Lazily created on first access of ``exception_handler`` and shared via the class.
    _exception_handler: Optional[ExceptionHandler] = None

    @property
    def default_port(self) -> int:
        return 3306

    @property
    def host_alias_query(self) -> str:
        return "SELECT CONCAT(@@hostname, ':', @@port)"

    @property
    def server_version_query(self) -> str:
        return "SHOW VARIABLES LIKE 'version_comment'"

    @property
    def exception_handler(self) -> Optional[ExceptionHandler]:
        owner = MysqlDatabaseDialect
        if owner._exception_handler is None:
            owner._exception_handler = Utils.initialize_class(
                "aws_advanced_python_wrapper.utils.mysql_exception_handler.MySQLExceptionHandler")
        return owner._exception_handler

    @property
    def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
        return MysqlDatabaseDialect._DIALECT_UPDATE_CANDIDATES

    def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
        was_in_transaction: bool = driver_dialect.is_in_transaction(conn)
        try:
            with closing(conn.cursor()) as cursor:
                cursor.execute(self.server_version_query)
                if any("mysql" in value.lower() for row in cursor for value in row):
                    return True
        except Exception:
            # Roll back only a transaction that the probe itself opened.
            if not was_in_transaction and driver_dialect.is_in_transaction(conn):
                conn.rollback()
        return False

    def get_host_list_provider_supplier(self) -> Callable:
        def supplier(provider_service, props):
            return ConnectionStringHostListProvider(provider_service, props)
        return supplier

    def prepare_conn_props(self, props: Properties):
        """No connection properties need to be added for plain MySQL."""
207
+
208
+
209
class PgDatabaseDialect(DatabaseDialect):
    """Dialect for a generic PostgreSQL server.

    A connection is recognised as PostgreSQL when the ``pg_proc`` catalog returns a row.
    """

    _DIALECT_UPDATE_CANDIDATES: Tuple[DialectCode, ...] = (
        DialectCode.AURORA_PG, DialectCode.MULTI_AZ_PG, DialectCode.RDS_PG)
    # Lazily created on first access of ``exception_handler`` and shared via the class.
    _exception_handler: Optional[ExceptionHandler] = None

    @property
    def default_port(self) -> int:
        return 5432

    @property
    def host_alias_query(self) -> str:
        return "SELECT CONCAT(inet_server_addr(), ':', inet_server_port())"

    @property
    def server_version_query(self) -> str:
        return "SELECT 'version', VERSION()"

    @property
    def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
        return PgDatabaseDialect._DIALECT_UPDATE_CANDIDATES

    @property
    def exception_handler(self) -> Optional[ExceptionHandler]:
        handler = PgDatabaseDialect._exception_handler
        if handler is None:
            handler = Utils.initialize_class(
                "aws_advanced_python_wrapper.utils.pg_exception_handler.SingleAzPgExceptionHandler")
            PgDatabaseDialect._exception_handler = handler
        return handler

    def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
        was_in_transaction: bool = driver_dialect.is_in_transaction(conn)
        try:
            with closing(conn.cursor()) as cursor:
                cursor.execute("SELECT 1 FROM pg_proc LIMIT 1")
                if cursor.fetchone() is not None:
                    return True
        except Exception:
            # Roll back only a transaction that the probe itself opened.
            if not was_in_transaction and driver_dialect.is_in_transaction(conn):
                conn.rollback()
        return False

    def get_host_list_provider_supplier(self) -> Callable:
        def supplier(provider_service, props):
            return ConnectionStringHostListProvider(provider_service, props)
        return supplier

    def prepare_conn_props(self, props: Properties):
        """No connection properties need to be added for plain PostgreSQL."""
255
+
256
+
257
class RdsMysqlDialect(MysqlDatabaseDialect):
    """MySQL dialect variant recognised by "source distribution" in ``version_comment``."""

    _DIALECT_UPDATE_CANDIDATES = (DialectCode.AURORA_MYSQL, DialectCode.MULTI_AZ_MYSQL)

    @property
    def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
        return RdsMysqlDialect._DIALECT_UPDATE_CANDIDATES

    def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
        was_in_transaction: bool = driver_dialect.is_in_transaction(conn)
        try:
            with closing(conn.cursor()) as cursor:
                cursor.execute(self.server_version_query)
                if any("source distribution" in value.lower() for row in cursor for value in row):
                    return True
        except Exception:
            # Roll back only a transaction that the probe itself opened.
            if not was_in_transaction and driver_dialect.is_in_transaction(conn):
                conn.rollback()
        return False
278
+
279
+
280
class RdsPgDialect(PgDatabaseDialect):
    """PostgreSQL dialect variant with ``rds_tools`` but without ``aurora_stat_utils`` in ``rds.extensions``."""

    _EXTENSIONS_QUERY = (
        "SELECT (setting LIKE '%rds_tools%') AS rds_tools, "
        "(setting LIKE '%aurora_stat_utils%') AS aurora_stat_utils "
        "FROM pg_settings "
        "WHERE name='rds.extensions'"
    )
    _DIALECT_UPDATE_CANDIDATES = (DialectCode.AURORA_PG, DialectCode.MULTI_AZ_PG)

    def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
        # Captured before the parent probe so any transaction it opens is rolled back on failure.
        was_in_transaction: bool = driver_dialect.is_in_transaction(conn)
        if not super().is_dialect(conn, driver_dialect):
            return False

        try:
            with closing(conn.cursor()) as cursor:
                cursor.execute(RdsPgDialect._EXTENSIONS_QUERY)
                for row in cursor:
                    has_rds_tools, has_aurora_utils = bool(row[0]), bool(row[1])
                    logger.debug(
                        "RdsPgDialect.RdsToolsAuroraUtils", str(has_rds_tools), str(has_aurora_utils))
                    if has_rds_tools and not has_aurora_utils:
                        return True
        except Exception:
            if not was_in_transaction and driver_dialect.is_in_transaction(conn):
                conn.rollback()
        return False

    @property
    def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
        return RdsPgDialect._DIALECT_UPDATE_CANDIDATES
311
+
312
+
313
class AuroraMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDialect):
    """Aurora MySQL dialect, recognised by the presence of the ``aurora_version`` variable."""

    _DIALECT_UPDATE_CANDIDATES = (DialectCode.MULTI_AZ_MYSQL,)
    _TOPOLOGY_QUERY = (
        "SELECT SERVER_ID, CASE WHEN SESSION_ID = 'MASTER_SESSION_ID' THEN TRUE ELSE FALSE END, "
        "CPU, REPLICA_LAG_IN_MILLISECONDS, LAST_UPDATE_TIMESTAMP "
        "FROM information_schema.replica_host_status "
        "WHERE time_to_sec(timediff(now(), LAST_UPDATE_TIMESTAMP)) <= 300 "
        "OR SESSION_ID = 'MASTER_SESSION_ID' "
    )
    _HOST_ID_QUERY = "SELECT @@aurora_server_id"
    _IS_READER_QUERY = "SELECT @@innodb_read_only"

    @property
    def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
        return AuroraMysqlDialect._DIALECT_UPDATE_CANDIDATES

    def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
        was_in_transaction: bool = driver_dialect.is_in_transaction(conn)
        try:
            with closing(conn.cursor()) as cursor:
                cursor.execute("SHOW VARIABLES LIKE 'aurora_version'")
                # The variable only exists on Aurora, so any row means an Aurora cluster.
                if cursor.fetchone() is not None:
                    return True
        except Exception:
            if not was_in_transaction and driver_dialect.is_in_transaction(conn):
                conn.rollback()
        return False

    def get_host_list_provider_supplier(self) -> Callable:
        def supplier(provider_service, props):
            return RdsHostListProvider(provider_service, props)
        return supplier
343
+
344
+
345
class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect):
    """Aurora PostgreSQL dialect.

    A PostgreSQL connection is treated as Aurora when ``rds.extensions`` lists
    ``aurora_stat_utils`` AND ``aurora_replica_status()`` returns at least one row.
    """

    _DIALECT_UPDATE_CANDIDATES: Tuple[DialectCode, ...] = (DialectCode.MULTI_AZ_PG,)

    _EXTENSIONS_QUERY = ("SELECT (setting LIKE '%aurora_stat_utils%') AS aurora_stat_utils "
                         "FROM pg_settings WHERE name='rds.extensions'")

    _HAS_TOPOLOGY_QUERY = "SELECT 1 FROM aurora_replica_status() LIMIT 1"

    _TOPOLOGY_QUERY = \
        ("SELECT SERVER_ID, CASE WHEN SESSION_ID = 'MASTER_SESSION_ID' THEN TRUE ELSE FALSE END, "
         "CPU, COALESCE(REPLICA_LAG_IN_MSEC, 0), LAST_UPDATE_TIMESTAMP "
         "FROM aurora_replica_status() "
         "WHERE EXTRACT(EPOCH FROM(NOW() - LAST_UPDATE_TIMESTAMP)) <= 300 OR SESSION_ID = 'MASTER_SESSION_ID' "
         "OR LAST_UPDATE_TIMESTAMP IS NULL")

    _HOST_ID_QUERY = "SELECT aurora_db_instance_identifier()"
    _IS_READER_QUERY = "SELECT pg_is_in_recovery()"

    @property
    def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
        return AuroraPgDialect._DIALECT_UPDATE_CANDIDATES

    def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
        """Return True if *conn* is connected to Aurora PostgreSQL.

        On a failing probe, a transaction opened by this check is rolled back.
        """
        # Capture the transaction state BEFORE the parent probe runs: its query may open a
        # transaction, and if the status were read afterwards a failure in the queries below
        # would leave that transaction (possibly aborted) open. Matches RdsPgDialect.is_dialect.
        initial_transaction_status: bool = driver_dialect.is_in_transaction(conn)
        if not super().is_dialect(conn, driver_dialect):
            return False

        has_extensions: bool = False
        has_topology: bool = False
        try:
            with closing(conn.cursor()) as cursor:
                cursor.execute(self._EXTENSIONS_QUERY)
                row = cursor.fetchone()
                if row and bool(row[0]):
                    logger.debug("AuroraPgDialect.HasExtensionsTrue")
                    has_extensions = True

            with closing(conn.cursor()) as cursor:
                cursor.execute(self._HAS_TOPOLOGY_QUERY)
                if cursor.fetchone() is not None:
                    logger.debug("AuroraPgDialect.HasTopologyTrue")
                    has_topology = True

            return has_extensions and has_topology
        except Exception:
            if not initial_transaction_status and driver_dialect.is_in_transaction(conn):
                conn.rollback()

        return False

    def get_host_list_provider_supplier(self) -> Callable:
        return lambda provider_service, props: RdsHostListProvider(provider_service, props)
398
+
399
+
400
class MultiAzMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDialect):
    """Dialect for RDS Multi-AZ DB clusters running MySQL.

    Recognised when ``mysql.rds_topology`` returns at least one row. It has no update
    candidates, so it is never refined further.
    """

    _TOPOLOGY_QUERY = "SELECT id, endpoint, port FROM mysql.rds_topology"
    _WRITER_HOST_QUERY = "SHOW REPLICA STATUS"
    _WRITER_HOST_COLUMN_INDEX = 39
    _HOST_ID_QUERY = "SELECT @@server_id"
    _IS_READER_QUERY = "SELECT @@read_only"

    @property
    def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
        return None

    def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
        initial_transaction_status: bool = driver_dialect.is_in_transaction(conn)
        try:
            with closing(conn.cursor()) as cursor:
                cursor.execute(MultiAzMysqlDialect._TOPOLOGY_QUERY)
                # Any topology row means this is a Multi-AZ cluster.
                if cursor.fetchall():
                    return True
        except Exception:
            # Roll back only a transaction that the probe itself opened.
            if not initial_transaction_status and driver_dialect.is_in_transaction(conn):
                conn.rollback()

        return False

    def get_host_list_provider_supplier(self) -> Callable:
        return lambda provider_service, props: MultiAzHostListProvider(
            provider_service,
            props,
            self._TOPOLOGY_QUERY,
            self._HOST_ID_QUERY,
            self._IS_READER_QUERY,
            self._WRITER_HOST_QUERY,
            self._WRITER_HOST_COLUMN_INDEX)

    def prepare_conn_props(self, props: Properties):
        """Add wrapper name/version connection attributes to ``props["conn_attrs"]``.

        Existing attributes are kept; the wrapper's keys override same-named ones.
        """
        # These props are added for RDS metrics purposes, they are not required for functional correctness.
        # The "conn_attrs" property value is specified as a dict.
        extra_conn_attrs = {
            "python_wrapper_name": "aws_python_driver",
            "python_wrapper_version": DriverInfo.DRIVER_VERSION}
        conn_attrs = props.get("conn_attrs")
        if conn_attrs is None:
            props["conn_attrs"] = extra_conn_attrs
        else:
            # Merge into a copy so a dict supplied (and possibly reused) by the caller is not mutated.
            merged_conn_attrs = dict(conn_attrs)
            merged_conn_attrs.update(extra_conn_attrs)
            props["conn_attrs"] = merged_conn_attrs
446
+
447
+
448
class MultiAzPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect):
    """Dialect for RDS Multi-AZ DB clusters running PostgreSQL.

    Recognised when the ``rds_tools`` writer-resource-id function returns a row. It has
    no update candidates and uses its own exception handler.
    """

    # The driver name passed to show_topology is used for RDS metrics purposes.
    # It is not required for functional correctness.
    _TOPOLOGY_QUERY = (
        f"SELECT id, endpoint, port FROM rds_tools.show_topology('aws_python_driver-{DriverInfo.DRIVER_VERSION}')"
    )
    _WRITER_HOST_QUERY = (
        "SELECT multi_az_db_cluster_source_dbi_resource_id FROM rds_tools.multi_az_db_cluster_source_dbi_resource_id()"
    )
    _HOST_ID_QUERY = "SELECT dbi_resource_id FROM rds_tools.dbi_resource_id()"
    _IS_READER_QUERY = "SELECT pg_is_in_recovery()"
    # Lazily created on first access of ``exception_handler``; separate from the parent's.
    _exception_handler: Optional[ExceptionHandler] = None

    @property
    def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
        return None

    @property
    def exception_handler(self) -> Optional[ExceptionHandler]:
        owner = MultiAzPgDialect
        if owner._exception_handler is None:
            owner._exception_handler = Utils.initialize_class(
                "aws_advanced_python_wrapper.utils.pg_exception_handler.MultiAzPgExceptionHandler")
        return owner._exception_handler

    def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
        was_in_transaction: bool = driver_dialect.is_in_transaction(conn)
        try:
            with closing(conn.cursor()) as cursor:
                cursor.execute(MultiAzPgDialect._WRITER_HOST_QUERY)
                if cursor.fetchone() is not None:
                    return True
        except Exception:
            if not was_in_transaction and driver_dialect.is_in_transaction(conn):
                conn.rollback()
        return False

    def get_host_list_provider_supplier(self) -> Callable:
        def supplier(provider_service, props):
            return MultiAzHostListProvider(
                provider_service,
                props,
                self._TOPOLOGY_QUERY,
                self._HOST_ID_QUERY,
                self._IS_READER_QUERY,
                self._WRITER_HOST_QUERY)
        return supplier
491
+
492
+
493
class UnknownDatabaseDialect(DatabaseDialect):
    """Fallback dialect used while the database engine has not been identified.

    It never matches a connection itself; instead it lists every concrete dialect,
    in probing order, as an update candidate.
    """

    _DIALECT_UPDATE_CANDIDATES: Optional[Tuple[DialectCode, ...]] = (
        DialectCode.MYSQL,
        DialectCode.PG,
        DialectCode.RDS_MYSQL,
        DialectCode.RDS_PG,
        DialectCode.AURORA_MYSQL,
        DialectCode.AURORA_PG,
        DialectCode.MULTI_AZ_MYSQL,
        DialectCode.MULTI_AZ_PG,
    )

    @property
    def default_port(self) -> int:
        return HostInfo.NO_PORT

    @property
    def host_alias_query(self) -> str:
        return ""

    @property
    def server_version_query(self) -> str:
        return ""

    @property
    def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
        return UnknownDatabaseDialect._DIALECT_UPDATE_CANDIDATES

    @property
    def exception_handler(self) -> Optional[ExceptionHandler]:
        return None

    def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
        return False

    def get_host_list_provider_supplier(self) -> Callable:
        def supplier(provider_service, props):
            return ConnectionStringHostListProvider(provider_service, props)
        return supplier

    def prepare_conn_props(self, props: Properties):
        """Nothing to prepare for an unidentified database."""
532
+
533
+
534
class DatabaseDialectManager(DatabaseDialectProvider):
    """Selects the :class:`DatabaseDialect` for a connection and caches it per endpoint.

    :meth:`get_dialect` chooses a dialect without a connection. It checks, in order: a
    registered custom dialect, the ``DIALECT`` wrapper property, the endpoint cache, and
    finally a guess from the target driver and the host name. When the choice is a guess,
    :meth:`query_for_dialect` later probes the live connection with each of the dialect's
    update candidates.
    """

    _ENDPOINT_CACHE_EXPIRATION_NS = 30 * 60_000_000_000  # 30 minutes
    # Endpoint URL -> resolved dialect code; class-level, so shared by all manager instances.
    _known_endpoint_dialects: CacheMap[str, DialectCode] = CacheMap()
    _custom_dialect: Optional[DatabaseDialect] = None
    # Executor handed to preserve_transaction_status_with_timeout when probing candidates.
    _executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="DatabaseDialectManagerExecutor")
    _known_dialects_by_code: Dict[DialectCode, DatabaseDialect] = {
        DialectCode.MYSQL: MysqlDatabaseDialect(),
        DialectCode.RDS_MYSQL: RdsMysqlDialect(),
        DialectCode.AURORA_MYSQL: AuroraMysqlDialect(),
        DialectCode.MULTI_AZ_MYSQL: MultiAzMysqlDialect(),
        DialectCode.PG: PgDatabaseDialect(),
        DialectCode.RDS_PG: RdsPgDialect(),
        DialectCode.AURORA_PG: AuroraPgDialect(),
        DialectCode.MULTI_AZ_PG: MultiAzPgDialect(),
        DialectCode.UNKNOWN: UnknownDatabaseDialect()
    }

    def __init__(self, props: Properties, rds_helper: Optional[RdsUtils] = None):
        self._props: Properties = props
        self._rds_helper: RdsUtils = rds_helper if rds_helper else RdsUtils()
        # True while the current dialect is a guess that query_for_dialect may still refine.
        self._can_update: bool = False
        self._dialect: DatabaseDialect = UnknownDatabaseDialect()
        self._dialect_code: DialectCode = DialectCode.UNKNOWN

    @staticmethod
    def get_custom_dialect():
        return DatabaseDialectManager._custom_dialect

    @staticmethod
    def set_custom_dialect(dialect: DatabaseDialect):
        DatabaseDialectManager._custom_dialect = dialect

    @staticmethod
    def reset_custom_dialect():
        DatabaseDialectManager._custom_dialect = None

    def reset_endpoint_cache(self):
        DatabaseDialectManager._known_endpoint_dialects.clear()

    def get_dialect(self, driver_dialect: str, props: Properties) -> DatabaseDialect:
        """Return the dialect to use before a connection has been inspected.

        :param driver_dialect: code of the target driver dialect (see DriverDialectCodes).
        :param props: connection properties; must contain ``"host"`` when no dialect is
            configured or cached.
        :raises AwsWrapperError: if the configured or cached dialect code is not known.
        """
        self._can_update = False

        if self._custom_dialect is not None:
            self._dialect_code = DialectCode.CUSTOM
            self._dialect = self._custom_dialect
            self._log_current_dialect()
            return self._dialect

        user_dialect_setting: Optional[str] = WrapperProperties.DIALECT.get(props)
        url = PropertiesUtils.get_url(props)

        dialect_code: Optional[DialectCode]
        if user_dialect_setting is None:
            dialect_code = DatabaseDialectManager._known_endpoint_dialects.get(url)
        else:
            dialect_code = DialectCode.from_string(user_dialect_setting)

        if dialect_code is not None:
            # An explicit or previously-resolved dialect is final (_can_update stays False).
            dialect: Optional[DatabaseDialect] = DatabaseDialectManager._known_dialects_by_code.get(dialect_code)
            if dialect is None:
                raise AwsWrapperError(
                    Messages.get_formatted("DatabaseDialectManager.UnknownDialectCode", str(dialect_code)))
            self._dialect_code = dialect_code
            self._dialect = dialect
            self._log_current_dialect()
            return dialect

        host: str = props["host"]
        target_driver_type: TargetDriverType = self._get_target_driver_type(driver_dialect)
        if target_driver_type is TargetDriverType.MYSQL:
            return self._select_dialect_for_host(
                host, DialectCode.AURORA_MYSQL, DialectCode.RDS_MYSQL, DialectCode.MYSQL)
        if target_driver_type is TargetDriverType.POSTGRES:
            return self._select_dialect_for_host(
                host, DialectCode.AURORA_PG, DialectCode.RDS_PG, DialectCode.PG)

        return self._use_updatable_dialect(DialectCode.UNKNOWN)

    def _select_dialect_for_host(self, host: str, cluster_code: DialectCode, rds_code: DialectCode,
                                 default_code: DialectCode) -> DatabaseDialect:
        """Pick the cluster, RDS or generic dialect code based on *host*; the result stays updatable."""
        rds_type = self._rds_helper.identify_rds_type(host)
        if rds_type.is_rds_cluster:
            return self._use_updatable_dialect(cluster_code)
        if rds_type.is_rds:
            return self._use_updatable_dialect(rds_code)
        return self._use_updatable_dialect(default_code)

    def _use_updatable_dialect(self, dialect_code: DialectCode) -> DatabaseDialect:
        """Make *dialect_code* current, allow query_for_dialect to refine it, log and return it."""
        self._can_update = True
        self._dialect_code = dialect_code
        self._dialect = DatabaseDialectManager._known_dialects_by_code[dialect_code]
        self._log_current_dialect()
        return self._dialect

    def _get_target_driver_type(self, driver_dialect: str) -> TargetDriverType:
        if driver_dialect == DriverDialectCodes.PSYCOPG:
            return TargetDriverType.POSTGRES
        if driver_dialect == DriverDialectCodes.MYSQL_CONNECTOR_PYTHON:
            return TargetDriverType.MYSQL

        return TargetDriverType.CUSTOM

    def query_for_dialect(self, url: str, host_info: Optional[HostInfo], conn: Connection,
                          driver_dialect: DriverDialect) -> DatabaseDialect:
        """Refine the current dialect by probing *conn* with each update candidate.

        The first candidate whose ``is_dialect`` returns True becomes current and is cached
        for *url* (and *host_info*'s URL). If none match, the current dialect is kept.

        :raises QueryTimeoutError: if a probe exceeds AUXILIARY_QUERY_TIMEOUT_SEC.
        :raises AwsWrapperError: if a candidate code is unknown, or the dialect is still UNKNOWN.
        """
        if not self._can_update:
            self._log_current_dialect()
            return self._dialect

        dialect_candidates = self._dialect.dialect_update_candidates if self._dialect is not None else None
        if dialect_candidates is not None:
            timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get(self._props)
            for dialect_code in dialect_candidates:
                dialect_candidate = DatabaseDialectManager._known_dialects_by_code.get(dialect_code)
                if dialect_candidate is None:
                    raise AwsWrapperError(
                        Messages.get_formatted("DatabaseDialectManager.UnknownDialectCode", str(dialect_code)))

                try:
                    cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
                        DatabaseDialectManager._executor,
                        timeout_sec,
                        driver_dialect,
                        conn)(dialect_candidate.is_dialect)
                    is_dialect = cursor_execute_func_with_timeout(conn, driver_dialect)
                except TimeoutError as e:
                    # Resolve the message key; the original passed the raw key as the message text.
                    raise QueryTimeoutError(Messages.get("DatabaseDialectManager.QueryForDialectTimeout")) from e

                if not is_dialect:
                    continue

                self._can_update = False
                self._dialect_code = dialect_code
                self._dialect = dialect_candidate
                self._cache_current_dialect(url, host_info)
                self._log_current_dialect()
                return self._dialect

        if self._dialect_code is None or self._dialect_code == DialectCode.UNKNOWN:
            raise AwsWrapperError(Messages.get("DatabaseDialectManager.UnknownDialect"))

        self._can_update = False
        self._cache_current_dialect(url, host_info)
        self._log_current_dialect()
        return self._dialect

    def _cache_current_dialect(self, url: str, host_info: Optional[HostInfo]):
        """Remember the current dialect code for *url* and, when given, for *host_info*'s URL."""
        DatabaseDialectManager._known_endpoint_dialects.put(
            url, self._dialect_code, DatabaseDialectManager._ENDPOINT_CACHE_EXPIRATION_NS)
        if host_info is not None:
            DatabaseDialectManager._known_endpoint_dialects.put(
                host_info.url, self._dialect_code, DatabaseDialectManager._ENDPOINT_CACHE_EXPIRATION_NS)

    def _log_current_dialect(self):
        dialect_class = "<null>" if self._dialect is None else type(self._dialect).__name__
        logger.debug("DatabaseDialectManager.CurrentDialectCanUpdate", self._dialect_code, dialect_class, self._can_update)