aws-advanced-python-wrapper 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. CONTRIBUTING.md +63 -0
  2. aws_advanced_python_wrapper/__init__.py +28 -0
  3. aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py +228 -0
  4. aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py +240 -0
  5. aws_advanced_python_wrapper/aws_secrets_manager_plugin.py +218 -0
  6. aws_advanced_python_wrapper/connect_time_plugin.py +69 -0
  7. aws_advanced_python_wrapper/connection_provider.py +232 -0
  8. aws_advanced_python_wrapper/database_dialect.py +708 -0
  9. aws_advanced_python_wrapper/default_plugin.py +144 -0
  10. aws_advanced_python_wrapper/developer_plugin.py +163 -0
  11. aws_advanced_python_wrapper/driver_configuration_profiles.py +44 -0
  12. aws_advanced_python_wrapper/driver_dialect.py +165 -0
  13. aws_advanced_python_wrapper/driver_dialect_codes.py +19 -0
  14. aws_advanced_python_wrapper/driver_dialect_manager.py +121 -0
  15. aws_advanced_python_wrapper/driver_info.py +18 -0
  16. aws_advanced_python_wrapper/errors.py +47 -0
  17. aws_advanced_python_wrapper/exception_handling.py +73 -0
  18. aws_advanced_python_wrapper/execute_time_plugin.py +58 -0
  19. aws_advanced_python_wrapper/failover_plugin.py +517 -0
  20. aws_advanced_python_wrapper/failover_result.py +42 -0
  21. aws_advanced_python_wrapper/fastest_response_strategy_plugin.py +345 -0
  22. aws_advanced_python_wrapper/federated_plugin.py +382 -0
  23. aws_advanced_python_wrapper/host_availability.py +86 -0
  24. aws_advanced_python_wrapper/host_list_provider.py +645 -0
  25. aws_advanced_python_wrapper/host_monitoring_plugin.py +728 -0
  26. aws_advanced_python_wrapper/host_selector.py +190 -0
  27. aws_advanced_python_wrapper/hostinfo.py +138 -0
  28. aws_advanced_python_wrapper/iam_plugin.py +195 -0
  29. aws_advanced_python_wrapper/mysql_driver_dialect.py +175 -0
  30. aws_advanced_python_wrapper/pep249.py +196 -0
  31. aws_advanced_python_wrapper/pg_driver_dialect.py +176 -0
  32. aws_advanced_python_wrapper/plugin.py +148 -0
  33. aws_advanced_python_wrapper/plugin_service.py +949 -0
  34. aws_advanced_python_wrapper/read_write_splitting_plugin.py +363 -0
  35. aws_advanced_python_wrapper/reader_failover_handler.py +252 -0
  36. aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +315 -0
  37. aws_advanced_python_wrapper/sql_alchemy_connection_provider.py +196 -0
  38. aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py +127 -0
  39. aws_advanced_python_wrapper/stale_dns_plugin.py +209 -0
  40. aws_advanced_python_wrapper/states/__init__.py +13 -0
  41. aws_advanced_python_wrapper/states/session_state.py +94 -0
  42. aws_advanced_python_wrapper/states/session_state_service.py +221 -0
  43. aws_advanced_python_wrapper/utils/__init__.py +13 -0
  44. aws_advanced_python_wrapper/utils/atomic.py +51 -0
  45. aws_advanced_python_wrapper/utils/cache_map.py +99 -0
  46. aws_advanced_python_wrapper/utils/concurrent.py +100 -0
  47. aws_advanced_python_wrapper/utils/decorators.py +70 -0
  48. aws_advanced_python_wrapper/utils/failover_mode.py +39 -0
  49. aws_advanced_python_wrapper/utils/iamutils.py +75 -0
  50. aws_advanced_python_wrapper/utils/log.py +75 -0
  51. aws_advanced_python_wrapper/utils/messages.py +36 -0
  52. aws_advanced_python_wrapper/utils/mysql_exception_handler.py +73 -0
  53. aws_advanced_python_wrapper/utils/notifications.py +37 -0
  54. aws_advanced_python_wrapper/utils/pg_exception_handler.py +115 -0
  55. aws_advanced_python_wrapper/utils/properties.py +492 -0
  56. aws_advanced_python_wrapper/utils/rds_url_type.py +36 -0
  57. aws_advanced_python_wrapper/utils/rdsutils.py +226 -0
  58. aws_advanced_python_wrapper/utils/sliding_expiration_cache.py +146 -0
  59. aws_advanced_python_wrapper/utils/telemetry/default_telemetry_factory.py +82 -0
  60. aws_advanced_python_wrapper/utils/telemetry/null_telemetry.py +55 -0
  61. aws_advanced_python_wrapper/utils/telemetry/open_telemetry.py +189 -0
  62. aws_advanced_python_wrapper/utils/telemetry/telemetry.py +85 -0
  63. aws_advanced_python_wrapper/utils/telemetry/xray_telemetry.py +126 -0
  64. aws_advanced_python_wrapper/utils/utils.py +89 -0
  65. aws_advanced_python_wrapper/wrapper.py +322 -0
  66. aws_advanced_python_wrapper/writer_failover_handler.py +347 -0
  67. aws_advanced_python_wrapper-1.0.0.dist-info/LICENSE +201 -0
  68. aws_advanced_python_wrapper-1.0.0.dist-info/METADATA +261 -0
  69. aws_advanced_python_wrapper-1.0.0.dist-info/RECORD +70 -0
  70. aws_advanced_python_wrapper-1.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,190 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License").
4
+ # You may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import random
18
+ from re import search
19
+ from typing import TYPE_CHECKING, Dict, List, Optional, Protocol, Tuple
20
+
21
+ from .host_availability import HostAvailability
22
+
23
+ if TYPE_CHECKING:
24
+ from .hostinfo import HostInfo, HostRole
25
+
26
+ from aws_advanced_python_wrapper.errors import AwsWrapperError
27
+ from aws_advanced_python_wrapper.utils.cache_map import CacheMap
28
+ from .pep249 import Error
29
+ from .utils.messages import Messages
30
+ from .utils.properties import Properties, WrapperProperties
31
+
32
+
33
class HostSelector(Protocol):
    """
    Strategy interface: given a collection of hosts, decide which single host to use.
    """

    def get_host(
            self,
            hosts: Tuple[HostInfo, ...],
            role: HostRole,
            props: Optional[Properties] = None) -> HostInfo:
        """
        Choose one host out of ``hosts``.

        :param hosts: the candidate hosts to choose from.
        :param role: the role the chosen host is expected to have.
        :param props: optional connection properties an implementation may consult.
        :return: the chosen host.
        """
        ...
40
+
41
+
42
class RandomHostSelector(HostSelector):
    """
    Host selector that picks at random among the available hosts that have the requested role.
    """

    def get_host(self, hosts: Tuple[HostInfo, ...], role: HostRole, props: Optional[Properties] = None) -> HostInfo:
        """
        Return a randomly chosen host from ``hosts``.

        Only hosts whose role equals ``role`` and whose availability is
        ``HostAvailability.AVAILABLE`` are considered. ``props`` is not used by this selector.

        :raises Error: if no host satisfies both conditions.
        """
        candidates = []
        for candidate in hosts:
            if candidate.role == role and candidate.get_availability() == HostAvailability.AVAILABLE:
                candidates.append(candidate)

        if not candidates:
            raise Error(Messages.get("HostSelector.NoEligibleHost"))

        return random.choice(candidates)
52
+
53
+
54
class RoundRobinClusterInfo:
    """
    Mutable round-robin bookkeeping for one cluster.

    Holds the host that was returned last, the per-host weights (keyed by host name),
    the weight applied to hosts without an explicit entry, and the number of remaining
    selections for the current host before moving on to the next one.
    """

    def __init__(self) -> None:
        self._last_host: Optional[HostInfo] = None
        # Created per instance on purpose: a class-level ``{}`` would be one dict shared by
        # every RoundRobinClusterInfo, so weights written for one cluster would leak into all others.
        self._cluster_weights_dict: Dict[str, int] = {}
        self._default_weight: int = 1
        self._weight_counter: int = 0

    @property
    def last_host(self) -> Optional[HostInfo]:
        """The host returned by the most recent selection, or ``None`` if there was none."""
        return self._last_host

    @last_host.setter
    def last_host(self, value):
        self._last_host = value

    @property
    def cluster_weights_dict(self) -> Dict[str, int]:
        """Mapping of host name to its round-robin weight. The returned dict is the live one."""
        return self._cluster_weights_dict

    @cluster_weights_dict.setter
    def cluster_weights_dict(self, value):
        self._cluster_weights_dict = value

    @property
    def default_weight(self):
        """Weight used for hosts that have no entry in ``cluster_weights_dict``."""
        return self._default_weight

    @default_weight.setter
    def default_weight(self, value):
        self._default_weight = value

    @property
    def weight_counter(self) -> int:
        """Remaining number of selections for the current host."""
        return self._weight_counter

    @weight_counter.setter
    def weight_counter(self, value):
        self._weight_counter = value
91
+
92
+
93
class RoundRobinHostSelector(HostSelector):
    """
    Host selector that cycles through the eligible hosts in host-name order.

    A host may be returned several times in a row according to its weight. The rotation
    state is kept in a class-level cache, with one ``RoundRobinClusterInfo`` stored under
    the name of every eligible host of a cluster.
    """
    _DEFAULT_WEIGHT: int = 1
    _DEFAULT_ROUND_ROBIN_CACHE_EXPIRE_NANOS = 60000000000 * 10  # 10 minutes
    _HOST_WEIGHT_PAIRS_PATTERN = r"((?P<host>[^:/?#]*):(?P<weight>.*))"
    _round_robin_cache: CacheMap[str, Optional[RoundRobinClusterInfo]] = CacheMap()

    def get_host(self, hosts: Tuple[HostInfo, ...], role: HostRole, props: Optional[Properties] = None) -> HostInfo:
        """
        Return the next host in the rotation.

        :param hosts: the candidate hosts.
        :param role: only hosts with this role (and availability ``AVAILABLE``) are eligible.
        :param props: optional properties supplying the default weight and the host:weight pairs;
            they are only read when no cached cluster info exists yet for the eligible hosts.
        :return: the selected host.
        :raises AwsWrapperError: if no host is eligible, if the cluster info is missing from the
            cache, or if the weight settings in ``props`` are invalid.
        """
        # Sorting by host name gives a stable rotation order regardless of the input order.
        eligible_hosts: List[HostInfo] = sorted(
            (host for host in hosts if host.role == role and host.get_availability() == HostAvailability.AVAILABLE),
            key=lambda host: host.host)
        if not eligible_hosts:
            raise AwsWrapperError(Messages.get_formatted("HostSelector.NoHostsMatchingRole", role))

        # Create new cache entries for provided hosts if necessary. All hosts point to the same cluster info.
        self._create_cache_entry_for_hosts(eligible_hosts, props)
        cluster_info: Optional[RoundRobinClusterInfo] = RoundRobinHostSelector._round_robin_cache.get(eligible_hosts[0].host)
        if cluster_info is None:
            raise AwsWrapperError(Messages.get("RoundRobinHostSelector.ClusterInfoNone"))

        # Locate the previously returned host among the eligible hosts (-1 when absent).
        last_host = cluster_info.last_host
        last_host_index: int = -1
        if last_host is not None:
            for index, host in enumerate(eligible_hosts):
                if host.host == last_host.host:
                    last_host_index = index

        if cluster_info.weight_counter > 0 and last_host_index != -1:
            # The previous host still has selections left from its weight.
            target_host_index = last_host_index
        else:
            # Advance to the next host. -1 (no previous host) and the last index both wrap to 0.
            target_host_index = (last_host_index + 1) % len(eligible_hosts)
            weight = cluster_info.cluster_weights_dict.get(eligible_hosts[target_host_index].host)
            cluster_info.weight_counter = cluster_info.default_weight if weight is None else weight

        cluster_info.weight_counter = cluster_info.weight_counter - 1
        cluster_info.last_host = eligible_hosts[target_host_index]
        return eligible_hosts[target_host_index]

    def _create_cache_entry_for_hosts(self, hosts: List[HostInfo], props: Optional[Properties]) -> None:
        """
        Make every host in ``hosts`` map to one shared cluster info in the cache.

        The first cached info found for any of the hosts is reused. If none is found, a new
        one is built from ``props``. Every host is then (re-)put, which also renews its expiration.
        """
        # The generator is consumed lazily, so cache lookups stop at the first hit.
        cluster_info = next(
            (info for info in (self._round_robin_cache.get(host.host) for host in hosts) if info is not None),
            None)

        if cluster_info is None:
            cluster_info = RoundRobinClusterInfo()
            self._update_cache_properties_for_round_robin_cluster_info(cluster_info, props)

        for host in hosts:
            self._round_robin_cache.put(
                host.host, cluster_info, RoundRobinHostSelector._DEFAULT_ROUND_ROBIN_CACHE_EXPIRE_NANOS)

    def _update_cache_properties_for_round_robin_cluster_info(self, round_robin_cluster_info: RoundRobinClusterInfo, props: Optional[Properties]):
        """
        Apply the default weight and the per-host weights from ``props`` to the given cluster info.

        :raises AwsWrapperError: if the default weight or any host:weight pair is invalid. In that
            case the cluster info's weight dict is left untouched.
        """
        if props is None:
            round_robin_cluster_info.default_weight = RoundRobinHostSelector._DEFAULT_WEIGHT
            return

        default_weight = WrapperProperties.ROUND_ROBIN_DEFAULT_WEIGHT.get_int(props)
        if default_weight < RoundRobinHostSelector._DEFAULT_WEIGHT:
            raise AwsWrapperError(Messages.get("RoundRobinHostSelector.RoundRobinInvalidDefaultWeight"))
        round_robin_cluster_info.default_weight = default_weight

        host_weights: Optional[str] = WrapperProperties.ROUND_ROBIN_HOST_WEIGHT_PAIRS.get(props)
        if not host_weights:
            return

        # Work on a copy and assign it at the end: the dict currently held by the cluster info
        # may be shared with other instances, and a failure part-way must not leave it half-updated.
        weights: Dict[str, int] = dict(round_robin_cluster_info.cluster_weights_dict)
        for pair in host_weights.split(","):
            host_name, weight = self._parse_host_weight_pair(pair)
            weights[host_name] = weight
        round_robin_cluster_info.cluster_weights_dict = weights

    @staticmethod
    def _parse_host_weight_pair(pair: str) -> Tuple[str, int]:
        """Split one ``host:weight`` entry into a validated ``(host, weight)`` tuple."""
        invalid_pairs_key = "RoundRobinHostSelector.RoundRobinInvalidHostWeightPairs"

        match = search(RoundRobinHostSelector._HOST_WEIGHT_PAIRS_PATTERN, pair)
        if not match:
            raise AwsWrapperError(Messages.get(invalid_pairs_key))

        host_name = match.group("host")
        host_weight = match.group("weight")
        if not host_name or not host_weight:
            raise AwsWrapperError(Messages.get(invalid_pairs_key))

        try:
            weight = int(host_weight)
        except ValueError as e:
            raise AwsWrapperError(Messages.get(invalid_pairs_key)) from e

        if weight < RoundRobinHostSelector._DEFAULT_WEIGHT:
            raise AwsWrapperError(Messages.get(invalid_pairs_key))

        return host_name, weight

    def clear_cache(self):
        """Remove all round-robin state from the class-level cache."""
        RoundRobinHostSelector._round_robin_cache.clear()
@@ -0,0 +1,138 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License").
4
+ # You may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from dataclasses import dataclass
18
+ from enum import Enum, auto
19
+ from typing import TYPE_CHECKING, ClassVar, FrozenSet, Optional, Set
20
+
21
+ from aws_advanced_python_wrapper.host_availability import (
22
+ HostAvailability, HostAvailabilityStrategy)
23
+
24
+ if TYPE_CHECKING:
25
+ from datetime import datetime
26
+
27
+
28
class HostRole(Enum):
    """Role of a database host."""

    # Member order matters: auto() numbers the members in definition order.
    UNKNOWN = auto()
    READER = auto()
    WRITER = auto()
32
+
33
+
34
@dataclass(eq=False)
class HostInfo:
    """
    Describes one database host: its address, role, availability, weight and aliases.

    Equality compares host, port, raw availability and role. Because ``__eq__`` is defined
    without ``__hash__``, instances are not hashable.
    """
    NO_PORT: ClassVar[int] = -1
    DEFAULT_WEIGHT = 100

    def __init__(
            self,
            host: str,
            port: int = NO_PORT,
            role: HostRole = HostRole.WRITER,
            availability: HostAvailability = HostAvailability.AVAILABLE,
            # NOTE(review): this default instance is created once, when the class is defined, and is
            # shared by every HostInfo that does not pass its own strategy - confirm it is stateless.
            host_availability_strategy=HostAvailabilityStrategy(),
            weight: int = DEFAULT_WEIGHT,
            host_id: Optional[str] = None,
            last_update_time: Optional[datetime] = None):
        """
        :param host: the host name.
        :param port: the port, or ``NO_PORT`` when no port is specified.
        :param role: the role of the host.
        :param availability: the raw availability of the host.
        :param host_availability_strategy: strategy consulted by :meth:`get_availability` and
            notified by :meth:`set_availability`; may be ``None``.
        :param weight: the weight of the host.
        :param host_id: optional identifier of the host.
        :param last_update_time: optional time of the last update.
        """
        self.host = host
        self.port = port
        self.role = role
        self._availability = availability
        self.host_availability_strategy = host_availability_strategy
        # Plain int. A trailing comma here would silently store a 1-tuple instead.
        self.weight = weight
        self.host_id = host_id
        self.last_update_time = last_update_time

        # _aliases holds only explicitly added aliases; _all_aliases also contains the host's own alias.
        self._aliases: Set[str] = set()
        self._all_aliases: Set[str] = {self.as_alias()}

    def __eq__(self, other: object):
        # Identity shortcut: an object always equals itself.
        if self is other:
            return True
        if not isinstance(other, HostInfo):
            return False

        return self.host == other.host \
            and self.port == other.port \
            and self._availability == other._availability \
            and self.role == other.role

    def __str__(self):
        return f"HostInfo({self.host}, {self.port}, {self.role}, {self._availability})"

    @property
    def url(self):
        """``host:port`` when a port is specified, otherwise just the host."""
        return self.as_alias()

    @property
    def aliases(self) -> FrozenSet[str]:
        """Immutable snapshot of the explicitly added aliases."""
        return frozenset(self._aliases)

    @property
    def all_aliases(self) -> FrozenSet[str]:
        """Immutable snapshot of all aliases, including the host's own alias."""
        return frozenset(self._all_aliases)

    def as_alias(self) -> str:
        """Return this host as an alias string: ``host:port`` if a port is specified, else ``host``."""
        return f"{self.host}:{self.port}" if self.is_port_specified() else self.host

    def add_alias(self, *aliases: str):
        """Add each given alias to both alias sets. Does nothing when called without arguments."""
        if not aliases:
            return

        for alias in aliases:
            self._aliases.add(alias)
            self._all_aliases.add(alias)

    def as_aliases(self) -> FrozenSet[str]:
        """Return an immutable snapshot of all aliases."""
        return frozenset(self.all_aliases)

    def remove_alias(self, *aliases):
        """Remove each given alias from both alias sets; unknown aliases are ignored."""
        if not aliases:
            return

        for alias in aliases:
            self._aliases.discard(alias)
            self._all_aliases.discard(alias)

    def reset_aliases(self):
        """Drop all aliases, leaving only the host's own alias in the full alias set."""
        self._aliases.clear()
        self._all_aliases.clear()
        self._all_aliases.add(self.as_alias())

    def is_port_specified(self) -> bool:
        """Return ``True`` unless the port equals ``NO_PORT``."""
        return self.port != HostInfo.NO_PORT

    def get_availability(self) -> HostAvailability:
        """Return the availability as reported by the strategy, or the raw value if there is no strategy."""
        if self.host_availability_strategy is not None:
            return self.host_availability_strategy.get_host_availability(self._availability)

        return self._availability

    def get_raw_availability(self) -> HostAvailability:
        """Return the stored availability without consulting the strategy."""
        return self._availability

    def set_availability(self, availability: HostAvailability):
        """Store the availability and, if a strategy is set, pass it on to the strategy."""
        self._availability = availability
        if self.host_availability_strategy is not None:
            self.host_availability_strategy.set_host_availability(availability)

    def get_host_availability_strategy(self):
        return self.host_availability_strategy

    def set_host_availability_strategy(self, host_availability_strategy):
        self.host_availability_strategy = host_availability_strategy
@@ -0,0 +1,195 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License").
4
+ # You may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import TYPE_CHECKING
18
+
19
+ from aws_advanced_python_wrapper.utils.iamutils import IamAuthUtils, TokenInfo
20
+
21
+ if TYPE_CHECKING:
22
+ from boto3 import Session
23
+ from aws_advanced_python_wrapper.driver_dialect import DriverDialect
24
+ from aws_advanced_python_wrapper.hostinfo import HostInfo
25
+ from aws_advanced_python_wrapper.pep249 import Connection
26
+ from aws_advanced_python_wrapper.plugin_service import PluginService
27
+
28
+ from datetime import datetime, timedelta
29
+ from typing import Callable, Dict, Optional, Set
30
+
31
+ import boto3
32
+
33
+ from aws_advanced_python_wrapper.errors import AwsWrapperError
34
+ from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
35
+ from aws_advanced_python_wrapper.utils.log import Logger
36
+ from aws_advanced_python_wrapper.utils.messages import Messages
37
+ from aws_advanced_python_wrapper.utils.properties import (Properties,
38
+ WrapperProperties)
39
+ from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
40
+ from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
41
+ TelemetryTraceLevel
42
+
43
+ # Module-level logger for this plugin module, created with the module's __name__.
+ logger = Logger(__name__)
44
+
45
+
46
class IamAuthPlugin(Plugin):
    """
    Connection plugin that uses an IAM authentication token as the connection password.

    Tokens are generated through the boto3 ``rds`` client and cached at class level, keyed by
    region, host, port and user, so that later connections can reuse a token until it expires.
    """
    _SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"}
    # Leave 30 second buffer to prevent time-of-check to time-of-use errors
    _DEFAULT_TOKEN_EXPIRATION_SEC = 15 * 60 - 30

    _rds_utils: RdsUtils = RdsUtils()
    # Shared by all plugin instances; keys are the strings built by _get_cache_key.
    _token_cache: Dict[str, TokenInfo] = {}

    def __init__(self, plugin_service: PluginService, session: Optional[Session] = None):
        """
        :param plugin_service: service giving access to the dialects and the telemetry factory.
        :param session: optional boto3 session to use; a new ``boto3.Session()`` is created
            on demand when none is given.
        """
        self._plugin_service = plugin_service
        self._session = session

        telemetry_factory = self._plugin_service.get_telemetry_factory()
        self._fetch_token_counter = telemetry_factory.create_counter("iam.fetch_token.count")
        self._cache_size_gauge = telemetry_factory.create_gauge(
            "iam.token_cache.size", lambda: len(IamAuthPlugin._token_cache))

    @property
    def subscribed_methods(self) -> Set[str]:
        return self._SUBSCRIBED_METHODS

    def connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            connect_func: Callable) -> Connection:
        return self._connect(host_info, props, connect_func)

    def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callable) -> Connection:
        """
        Set an IAM token as the password in ``props`` and open the connection with ``connect_func``.

        A cached, unexpired token is used when available; otherwise a new one is generated and
        cached. If connecting with a cached token fails with a login error, a fresh token is
        generated and the connection is attempted once more.

        :raises AwsWrapperError: if no user is configured, if the region cannot be determined,
            or if the connection attempt(s) fail.
        """
        user = WrapperProperties.USER.get(props)
        if not user:
            raise AwsWrapperError(Messages.get_formatted("IamPlugin.IsNoneOrEmpty", WrapperProperties.USER.name))

        host = IamAuthUtils.get_iam_host(props, host_info)
        # An explicitly configured region wins; otherwise it is derived from the host name.
        region = WrapperProperties.IAM_REGION.get(props) or self._get_rds_region(host)
        port = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
        token_expiration_sec: int = WrapperProperties.IAM_EXPIRATION.get_int(props)

        cache_key: str = self._get_cache_key(user, host, port, region)
        token_info = IamAuthPlugin._token_cache.get(cache_key)

        # Decide once whether a cached token is used. Re-evaluating is_expired() after a failed
        # connect could flip the answer and skip the retry for a token that expired in between.
        if token_info is not None and not token_info.is_expired():
            is_cached_token = True
            # NOTE(review): the token is a credential and is handed to the logger here - consider masking.
            logger.debug("IamAuthPlugin.UseCachedIamToken", token_info.token)
            self._plugin_service.driver_dialect.set_password(props, token_info.token)
        else:
            is_cached_token = False
            self._generate_and_cache_token(props, host, port, region, cache_key, token_expiration_sec)

        try:
            return connect_func()

        except Exception as e:
            logger.debug("IamAuthPlugin.ConnectException", e)

            if not self._plugin_service.is_login_exception(error=e) or not is_cached_token:
                raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.ConnectException", e)) from e

            # Login unsuccessful with cached token
            # Try to generate a new token and try to connect again
            # The new token replaces the stale entry under the same cache key.
            self._generate_and_cache_token(props, host, port, region, cache_key, token_expiration_sec)

            try:
                return connect_func()
            except Exception as e:
                raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.UnhandledException", e)) from e

    def _generate_and_cache_token(
            self,
            props: Properties,
            host: Optional[str],
            port: int,
            region: Optional[str],
            cache_key: str,
            token_expiration_sec: int) -> None:
        """Generate a new token, set it as the password in ``props`` and store it under ``cache_key``."""
        # The expiry is taken before the token is requested, so the cached lifetime is never overestimated.
        token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec)
        token: str = self._generate_authentication_token(props, host, port, region)
        # NOTE(review): the token is a credential and is handed to the logger here - consider masking.
        logger.debug("IamAuthPlugin.GeneratedNewIamToken", token)
        self._plugin_service.driver_dialect.set_password(props, token)
        IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)

    def force_connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            force_connect_func: Callable) -> Connection:
        return self._connect(host_info, props, force_connect_func)

    def _generate_authentication_token(self,
                                       props: Properties,
                                       hostname: Optional[str],
                                       port: Optional[int],
                                       region: Optional[str]) -> str:
        """
        Request a database authentication token from the ``rds`` client for the configured user.

        The call is counted and wrapped in a telemetry context. Any exception is recorded on
        the context and re-raised unchanged.
        """
        telemetry_factory = self._plugin_service.get_telemetry_factory()
        context = telemetry_factory.open_telemetry_context("fetch IAM token", TelemetryTraceLevel.NESTED)
        self._fetch_token_counter.inc()

        try:
            session = self._session if self._session else boto3.Session()
            client = session.client(
                'rds',
                region_name=region,
            )

            try:
                return client.generate_db_auth_token(
                    DBHostname=hostname,
                    Port=port,
                    DBUsername=WrapperProperties.USER.get(props)
                )
            finally:
                # Close the client even when token generation fails.
                client.close()
        except Exception as ex:
            context.set_success(False)
            context.set_exception(ex)
            raise
        finally:
            context.close_context()

    def _get_cache_key(self, user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str:
        """Build the token cache key in the form ``region:hostname:port:user``."""
        return f"{region}:{hostname}:{port}:{user}"

    def _get_rds_region(self, hostname: Optional[str]) -> str:
        """
        Derive the region from ``hostname`` and check that the session lists it for ``rds``.

        :raises AwsWrapperError: if no region can be derived from the host name, or if the
            region is not among the session's available ``rds`` regions.
        """
        rds_region = self._rds_utils.get_rds_region(hostname) if hostname else None

        if not rds_region:
            exception_message = "RdsUtils.UnsupportedHostname"
            logger.debug(exception_message, hostname)
            raise AwsWrapperError(Messages.get_formatted(exception_message, hostname))

        session = self._session if self._session else boto3.Session()
        if rds_region not in session.get_available_regions("rds"):
            exception_message = "AwsSdk.UnsupportedRegion"
            logger.debug(exception_message, rds_region)
            raise AwsWrapperError(Messages.get_formatted(exception_message, rds_region))

        return rds_region
192
+
193
class IamAuthPluginFactory(PluginFactory):
    """Factory that creates :class:`IamAuthPlugin` instances."""

    def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
        """
        Build a new IAM authentication plugin bound to ``plugin_service``.

        ``props`` is part of the factory signature but is not used when building this plugin.
        """
        plugin: Plugin = IamAuthPlugin(plugin_service)
        return plugin