redis 6.4.0__py3-none-any.whl → 7.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,785 @@
1
+ import enum
2
+ import ipaddress
3
+ import logging
4
+ import re
5
+ import threading
6
+ import time
7
+ from abc import ABC, abstractmethod
8
+ from typing import TYPE_CHECKING, Optional, Union
9
+
10
+ from redis.typing import Number
11
+
12
+
13
class MaintenanceState(enum.Enum):
    """Maintenance state tracked on connections and in pool connection kwargs."""

    NONE = "none"  # no maintenance in progress
    MOVING = "moving"  # set while a NodeMovingEvent is being handled
    MAINTENANCE = "maintenance"  # set when a migration/failover starts
17
+
18
+
19
class EndpointType(enum.Enum):
    """
    Endpoint types accepted by the CLIENT MAINT_NOTIFICATIONS command.

    Converting a member with ``str()`` yields its raw value, e.g. ``"internal-ip"``.
    """

    INTERNAL_IP = "internal-ip"
    INTERNAL_FQDN = "internal-fqdn"
    EXTERNAL_IP = "external-ip"
    EXTERNAL_FQDN = "external-fqdn"
    NONE = "none"

    def __str__(self):
        """Render the bare value instead of the default ``EndpointType.NAME``."""
        return str(self.value)
31
+
32
+
33
+ if TYPE_CHECKING:
34
+ from redis.connection import (
35
+ BlockingConnectionPool,
36
+ ConnectionInterface,
37
+ ConnectionPool,
38
+ )
39
+
40
+
41
class MaintenanceEvent(ABC):
    """
    Abstract base class for maintenance events pushed by the Redis server.

    Holds the event id and TTL and records, at construction time, a
    ``time.monotonic()`` timestamp from which the expiry deadline is derived.
    Concrete events must implement ``__repr__``, ``__eq__`` and ``__hash__``.

    Attributes:
        id (int): Unique identifier for this event
        ttl (int): Time-to-live in seconds for this notification
        creation_time (float): Monotonic timestamp taken when the event was created
        expire_at (float): ``creation_time + ttl``, computed at construction
    """

    def __init__(self, id: int, ttl: int):
        """
        Store the id and TTL and stamp the creation time.

        Args:
            id (int): Unique identifier for this event
            ttl (int): Time-to-live in seconds for this notification
        """
        created = time.monotonic()
        self.id = id
        self.ttl = ttl
        self.creation_time = created
        self.expire_at = created + ttl

    def is_expired(self) -> bool:
        """
        Tell whether the TTL has elapsed since the event was created.

        The deadline is recomputed from ``creation_time`` and the current
        ``ttl`` rather than read from ``expire_at``.

        Returns:
            bool: True once the current monotonic time is past the deadline
        """
        deadline = self.creation_time + self.ttl
        return time.monotonic() > deadline

    @abstractmethod
    def __repr__(self) -> str:
        """
        Return a string representation of the event.

        Every concrete subclass must provide this.

        Returns:
            str: String representation of the event
        """
        ...

    @abstractmethod
    def __eq__(self, other) -> bool:
        """
        Compare this event with another object.

        Every concrete subclass must provide this; events are typically equal
        when they share the same id and concrete type.

        Args:
            other: The object to compare with

        Returns:
            bool: True if the events are equal, False otherwise
        """
        ...

    @abstractmethod
    def __hash__(self) -> int:
        """
        Return a hash value so events can live in sets and dict keys.

        Every concrete subclass must provide this, consistently with ``__eq__``.

        Returns:
            int: Hash value for the event
        """
        ...
118
+
119
+
120
class NodeMovingEvent(MaintenanceEvent):
    """
    Notification that a node is being replaced by a new node.

    Received during cluster rebalancing or maintenance operations. The new
    endpoint may be missing from the notification, in which case
    ``new_node_host`` / ``new_node_port`` are ``None``.
    """

    def __init__(
        self,
        id: int,
        new_node_host: Optional[str],
        new_node_port: Optional[int],
        ttl: int,
    ):
        """
        Initialize a new NodeMovingEvent.

        Args:
            id (int): Unique identifier for this event
            new_node_host (str): Hostname or IP address of the new replacement node
            new_node_port (int): Port number of the new replacement node
            ttl (int): Time-to-live in seconds for this notification
        """
        super().__init__(id, ttl)
        self.new_node_host = new_node_host
        self.new_node_port = new_node_port

    def __repr__(self) -> str:
        deadline = self.expire_at
        left = max(0, deadline - time.monotonic())
        parts = [
            f"id={self.id}",
            f"new_node_host='{self.new_node_host}'",
            f"new_node_port={self.new_node_port}",
            f"ttl={self.ttl}",
            f"creation_time={self.creation_time}",
            f"expires_at={deadline}",
            f"remaining={left:.1f}s",
            f"expired={self.is_expired()}",
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def __eq__(self, other) -> bool:
        """
        Equal when ``other`` is a NodeMovingEvent with the same id,
        new_node_host and new_node_port.
        """
        if isinstance(other, NodeMovingEvent):
            return (
                self.id == other.id
                and self.new_node_host == other.new_node_host
                and self.new_node_port == other.new_node_port
            )
        return False

    def __hash__(self) -> int:
        """
        Hash on (class name, id, host, normalized port).

        The port is converted with ``int()``; a falsy port hashes as ``None``
        and a port that ``int()`` rejects with ValueError hashes as ``0``.

        Returns:
            int: Hash value for the event
        """
        port = self.new_node_port
        try:
            normalized_port = int(port) if port else None
        except ValueError:
            normalized_port = 0

        key = (
            self.__class__.__name__,
            int(self.id),
            str(self.new_node_host),
            normalized_port,
        )
        return hash(key)
198
+
199
+
200
class NodeMigratingEvent(MaintenanceEvent):
    """
    Notification that a Redis cluster node has started migrating slots.

    Received when a node begins moving its slots to another node during
    cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this event
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        left = max(0, deadline - time.monotonic())
        body = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({body})"

    def __eq__(self, other) -> bool:
        """Equal when ``other`` has exactly the same type and the same id."""
        if isinstance(other, NodeMigratingEvent):
            return self.id == other.id and type(self) is type(other)
        return False

    def __hash__(self) -> int:
        """
        Hash on the concrete class name and the integer id.

        Returns:
            int: Hash value for the event
        """
        key = (self.__class__.__name__, int(self.id))
        return hash(key)
247
+
248
+
249
class NodeMigratedEvent(MaintenanceEvent):
    """
    Notification that a Redis cluster node has finished migrating slots.

    This event takes no TTL argument; ``DEFAULT_TTL`` seconds is used.

    Args:
        id (int): Unique identifier for this event
    """

    # TTL in seconds applied to every instance.
    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, ttl=NodeMigratedEvent.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        left = max(0, deadline - time.monotonic())
        body = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({body})"

    def __eq__(self, other) -> bool:
        """Equal when ``other`` has exactly the same type and the same id."""
        if isinstance(other, NodeMigratedEvent):
            return self.id == other.id and type(self) is type(other)
        return False

    def __hash__(self) -> int:
        """
        Hash on the concrete class name and the integer id.

        Returns:
            int: Hash value for the event
        """
        key = (self.__class__.__name__, int(self.id))
        return hash(key)
297
+
298
+
299
class NodeFailingOverEvent(MaintenanceEvent):
    """
    Notification that a Redis cluster node has started a failover.

    Received when a node begins failing over during cluster maintenance
    operations or while handling node failures.

    Args:
        id (int): Unique identifier for this event
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        left = max(0, deadline - time.monotonic())
        body = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({body})"

    def __eq__(self, other) -> bool:
        """Equal when ``other`` has exactly the same type and the same id."""
        if isinstance(other, NodeFailingOverEvent):
            return self.id == other.id and type(self) is type(other)
        return False

    def __hash__(self) -> int:
        """
        Hash on the concrete class name and the integer id.

        Returns:
            int: Hash value for the event
        """
        key = (self.__class__.__name__, int(self.id))
        return hash(key)
346
+
347
+
348
class NodeFailedOverEvent(MaintenanceEvent):
    """
    Notification that a Redis cluster node has completed a failover.

    This event takes no TTL argument; ``DEFAULT_TTL`` seconds is used.

    Args:
        id (int): Unique identifier for this event
    """

    # TTL in seconds applied to every instance.
    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, ttl=NodeFailedOverEvent.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        left = max(0, deadline - time.monotonic())
        body = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({body})"

    def __eq__(self, other) -> bool:
        """Equal when ``other`` has exactly the same type and the same id."""
        if isinstance(other, NodeFailedOverEvent):
            return self.id == other.id and type(self) is type(other)
        return False

    def __hash__(self) -> int:
        """
        Hash on the concrete class name and the integer id.

        Returns:
            int: Hash value for the event
        """
        key = (self.__class__.__name__, int(self.id))
        return hash(key)
396
+
397
+
398
+ def _is_private_fqdn(host: str) -> bool:
399
+ """
400
+ Determine if an FQDN is likely to be internal/private.
401
+
402
+ This uses heuristics based on RFC 952 and RFC 1123 standards:
403
+ - .local domains (RFC 6762 - Multicast DNS)
404
+ - .internal domains (common internal convention)
405
+ - Single-label hostnames (no dots)
406
+ - Common internal TLDs
407
+
408
+ Args:
409
+ host (str): The FQDN to check
410
+
411
+ Returns:
412
+ bool: True if the FQDN appears to be internal/private
413
+ """
414
+ host_lower = host.lower().rstrip(".")
415
+
416
+ # Single-label hostnames (no dots) are typically internal
417
+ if "." not in host_lower:
418
+ return True
419
+
420
+ # Common internal/private domain patterns
421
+ internal_patterns = [
422
+ r"\.local$", # mDNS/Bonjour domains
423
+ r"\.internal$", # Common internal convention
424
+ r"\.corp$", # Corporate domains
425
+ r"\.lan$", # Local area network
426
+ r"\.intranet$", # Intranet domains
427
+ r"\.private$", # Private domains
428
+ ]
429
+
430
+ for pattern in internal_patterns:
431
+ if re.search(pattern, host_lower):
432
+ return True
433
+
434
+ # If none of the internal patterns match, assume it's external
435
+ return False
436
+
437
+
438
class MaintenanceEventsConfig:
    """
    Configuration for how maintenance events are handled.

    Maintenance events are delivered as push notifications (node moving, slot
    migration, failover, ...). This object controls:
    - whether the events are handled at all;
    - whether connections are proactively reconnected when a node is replaced;
    - how socket timeouts are relaxed while maintenance is in progress;
    - which endpoint type is requested from the server.
    """

    def __init__(
        self,
        enabled: bool = True,
        proactive_reconnect: bool = True,
        relax_timeout: Optional[Number] = 10,
        endpoint_type: Optional[EndpointType] = None,
    ):
        """
        Initialize a new MaintenanceEventsConfig.

        Args:
            enabled (bool): Whether to enable maintenance events handling.
                Defaults to True.
            proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
                Defaults to True.
            relax_timeout (Optional[Number]): The relaxed timeout to use for connections during
                maintenance. ``-1`` disables timeout relaxing. ``None`` makes operations block
                until a response is received. Defaults to 10.
            endpoint_type (Optional[EndpointType]): Override for the endpoint type to use in
                CLIENT MAINT_NOTIFICATIONS. The enum's string value (e.g. ``"internal-ip"``) is
                also accepted and converted to the enum member. If None, the endpoint type is
                determined automatically from the host and the connection's resolved IP
                (see ``get_endpoint_type``). Defaults to None.

        Raises:
            ValueError: If endpoint_type is provided but is not a valid endpoint type.
        """
        self.enabled = enabled
        self.relax_timeout = relax_timeout
        self.proactive_reconnect = proactive_reconnect
        if endpoint_type is not None and not isinstance(endpoint_type, EndpointType):
            # Enforce the documented contract: looking the value up in the enum
            # raises ValueError for anything that is not a valid endpoint type.
            endpoint_type = EndpointType(endpoint_type)
        self.endpoint_type = endpoint_type

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"enabled={self.enabled}, "
            f"proactive_reconnect={self.proactive_reconnect}, "
            f"relax_timeout={self.relax_timeout}, "
            f"endpoint_type={self.endpoint_type!r}"
            f")"
        )

    def is_relax_timeouts_enabled(self) -> bool:
        """
        Check if the relax_timeout is enabled. The '-1' value is used to disable the relax_timeout.
        If relax_timeout is set to None, it will make the operation blocking
        and waiting until any response is received.

        Returns:
            True if the relax_timeout is enabled, False otherwise.
        """
        return self.relax_timeout != -1

    def get_endpoint_type(
        self, host: str, connection: "ConnectionInterface"
    ) -> EndpointType:
        """
        Determine the endpoint type to request in CLIENT MAINT_NOTIFICATIONS.

        Logic:
        1. If endpoint_type was set explicitly, return it.
        2. If ``host`` is an IP address, return INTERNAL_IP for private
           addresses and EXTERNAL_IP otherwise.
        3. Otherwise ``host`` is treated as an FQDN. If the connection reports a
           resolved IP, return INTERNAL_FQDN / EXTERNAL_FQDN depending on whether
           that IP is private.
        4. As a last resort, classify the FQDN by name (``_is_private_fqdn``).

        Args:
            host: User provided hostname to analyze
            connection: The connection whose resolved IP is used for FQDN hosts

        Returns:
            EndpointType: The configured override, or one of INTERNAL_IP,
            EXTERNAL_IP, INTERNAL_FQDN or EXTERNAL_FQDN.
        """

        # If endpoint_type is explicitly set, use it
        if self.endpoint_type is not None:
            return self.endpoint_type

        # Check if the host is an IP address
        try:
            ip_addr = ipaddress.ip_address(host)
            # Host is an IP address - use it directly
            is_private = ip_addr.is_private
            return EndpointType.INTERNAL_IP if is_private else EndpointType.EXTERNAL_IP
        except ValueError:
            # Host is an FQDN - need to check resolved IP to determine internal vs external
            pass

        # Host is an FQDN, get the resolved IP to determine if it's internal or external
        resolved_ip = connection.get_resolved_ip()

        if resolved_ip:
            try:
                ip_addr = ipaddress.ip_address(resolved_ip)
                is_private = ip_addr.is_private
                # Use FQDN types since the original host was an FQDN
                return (
                    EndpointType.INTERNAL_FQDN
                    if is_private
                    else EndpointType.EXTERNAL_FQDN
                )
            except ValueError:
                # Not expected for an IP obtained from the socket; fall through
                # to the name-based heuristics.
                pass

        # Final fallback: use heuristics on the FQDN itself
        is_private = _is_private_fqdn(host)
        return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN
551
+
552
+
553
class MaintenanceEventPoolHandler:
    """
    Connection-pool level handler for maintenance push notifications.

    Only ``NodeMovingEvent`` is acted upon:
    - matching connections, and the kwargs used for new connections, are
      switched to the MOVING state;
    - they may get relaxed timeouts and the new host address;
    - they may be proactively reconnected.
    After the event's TTL, a timer reverts those settings. Any other event type
    is logged as unhandled.
    """

    def __init__(
        self,
        pool: Union["ConnectionPool", "BlockingConnectionPool"],
        config: MaintenanceEventsConfig,
    ) -> None:
        self.pool = pool
        self.config = config
        # Events already handled; duplicates are ignored until the event
        # expires and is pruned by remove_expired_notifications().
        self._processed_events: set = set()
        # Re-entrant lock serializing event handling and _processed_events access.
        self._lock = threading.RLock()
        # Connection whose peer address selects which pooled connections a
        # MOVING event applies to (set via set_connection()).
        self.connection = None

    def set_connection(self, connection: "ConnectionInterface"):
        """Remember the connection whose peer address identifies the moving node."""
        self.connection = connection

    def remove_expired_notifications(self):
        """Drop processed events whose TTL has elapsed."""
        with self._lock:
            for notification in tuple(self._processed_events):
                if notification.is_expired():
                    self._processed_events.remove(notification)

    def handle_event(self, notification: MaintenanceEvent):
        """
        Dispatch a maintenance notification.

        Expired events are pruned first. A ``NodeMovingEvent`` is passed to
        ``handle_node_moving_event``; any other type is logged as an error.
        """
        self.remove_expired_notifications()

        if isinstance(notification, NodeMovingEvent):
            return self.handle_node_moving_event(notification)
        logging.error("Unhandled notification type: %s", notification)

    def handle_node_moving_event(self, event: NodeMovingEvent):
        """
        Apply a MOVING notification to the pool.

        Does nothing when both proactive reconnect and timeout relaxing are
        disabled, or when an equal event has already been processed and has not
        yet expired. Otherwise it:
        - updates matching connections and the kwargs for new connections;
        - optionally triggers a proactive reconnect (immediately when the new
          host is known, after ``ttl / 2`` seconds otherwise);
        - schedules ``handle_node_moved_event`` after ``ttl`` seconds.
        """
        if (
            not self.config.proactive_reconnect
            and not self.config.is_relax_timeouts_enabled()
        ):
            return
        with self._lock:
            if event in self._processed_events:
                # The event has already been handled - nothing to do.
                return

            with self.pool._lock:
                # Current connected address - if any. This is the address that
                # is being moved; only connections to the same address are updated.
                moving_address_src = (
                    self.connection.getpeername() if self.connection else None
                )

                # Maintenance mode is only available on pools exposing
                # set_in_maintenance (BlockingConnectionPool).
                supports_maintenance_mode = bool(
                    getattr(self.pool, "set_in_maintenance", False)
                )
                if supports_maintenance_mode:
                    self.pool.set_in_maintenance(True)
                try:
                    # Update maintenance state, timeout and optionally host
                    # address for matching connections, free ones included.
                    self.pool.update_connections_settings(
                        state=MaintenanceState.MOVING,
                        maintenance_event_hash=hash(event),
                        relax_timeout=self.config.relax_timeout,
                        host_address=event.new_node_host,
                        matching_address=moving_address_src,
                        matching_pattern="connected_address",
                        update_event_hash=True,
                        include_free_connections=True,
                    )

                    if self.config.proactive_reconnect:
                        if event.new_node_host is not None:
                            self.run_proactive_reconnect(moving_address_src)
                        else:
                            # No new host in the notification: reconnect after
                            # half the TTL instead of immediately.
                            threading.Timer(
                                event.ttl / 2,
                                self.run_proactive_reconnect,
                                args=(moving_address_src,),
                            ).start()

                    # Settings for connections created from now on: MOVING
                    # state, the new host (if any) and relaxed timeouts (if enabled).
                    kwargs: dict = {
                        "maintenance_state": MaintenanceState.MOVING,
                        "maintenance_event_hash": hash(event),
                    }
                    if event.new_node_host is not None:
                        # The host is left unchanged when the MOVING
                        # notification does not carry a new node host.
                        kwargs["host"] = event.new_node_host
                    if self.config.is_relax_timeouts_enabled():
                        kwargs["socket_timeout"] = self.config.relax_timeout
                        kwargs["socket_connect_timeout"] = self.config.relax_timeout
                    self.pool.update_connection_kwargs(**kwargs)
                finally:
                    # Always leave maintenance mode, even if an update above
                    # raised, so a failure does not leave the pool flagged
                    # as being in maintenance.
                    if supports_maintenance_mode:
                        self.pool.set_in_maintenance(False)

            threading.Timer(
                event.ttl, self.handle_node_moved_event, args=(event,)
            ).start()

            self._processed_events.add(event)

    def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
        """
        Run proactive reconnect for the pool.
        Active connections are marked for reconnect after they complete the current command.
        Inactive connections are disconnected and will be connected on next use.
        """
        with self._lock:
            with self.pool._lock:
                # take care for the active connections in the pool
                # mark them for reconnect after they complete the current command
                self.pool.update_active_connections_for_reconnect(
                    moving_address_src=moving_address_src,
                )
                # take care for the inactive connections in the pool
                # delete them and create new ones
                self.pool.disconnect_free_connections(
                    moving_address_src=moving_address_src,
                )

    def handle_node_moved_event(self, event: NodeMovingEvent):
        """
        Handle the cleanup after a node moving event expires.

        The pool's connection kwargs are restored from their ``orig_*`` values
        only if they are still tagged with this event's hash. Connections
        tagged with this event's hash are reset to the NONE state.
        """
        event_hash = hash(event)

        with self._lock:
            # if the current maintenance_event_hash in kwargs is not matching the event
            # it means there has been a new moving event after this one
            # and we don't need to revert the kwargs yet
            if self.pool.connection_kwargs.get("maintenance_event_hash") == event_hash:
                orig_host = self.pool.connection_kwargs.get("orig_host_address")
                orig_socket_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_timeout"
                )
                orig_connect_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_connect_timeout"
                )
                kwargs: dict = {
                    "maintenance_state": MaintenanceState.NONE,
                    "maintenance_event_hash": None,
                    "host": orig_host,
                    "socket_timeout": orig_socket_timeout,
                    "socket_connect_timeout": orig_connect_timeout,
                }
                self.pool.update_connection_kwargs(**kwargs)

            with self.pool._lock:
                reset_relax_timeout = self.config.is_relax_timeouts_enabled()
                reset_host_address = self.config.proactive_reconnect

                self.pool.update_connections_settings(
                    relax_timeout=-1,
                    state=MaintenanceState.NONE,
                    maintenance_event_hash=None,
                    matching_event_hash=event_hash,
                    matching_pattern="event_hash",
                    update_event_hash=True,
                    reset_relax_timeout=reset_relax_timeout,
                    reset_host_address=reset_host_address,
                    include_free_connections=True,
                )
732
+
733
+
734
class MaintenanceEventConnectionHandler:
    """
    Per-connection handler for migration/failover maintenance events.

    "Starting" events (migrating / failing over) put the connection into the
    MAINTENANCE state with relaxed timeouts. "Completed" events (migrated /
    failed over) restore the timeouts and reset the state to NONE. Both are
    skipped while the connection is MOVING or when timeout relaxing is disabled.
    """

    # 1 = "starting maintenance" events, 0 = "completed maintenance" events
    _EVENT_TYPES: dict[type["MaintenanceEvent"], int] = {
        NodeMigratingEvent: 1,
        NodeFailingOverEvent: 1,
        NodeMigratedEvent: 0,
        NodeFailedOverEvent: 0,
    }

    def __init__(
        self, connection: "ConnectionInterface", config: MaintenanceEventsConfig
    ) -> None:
        self.connection = connection
        self.config = config

    def handle_event(self, event: MaintenanceEvent):
        """Route ``event`` by its exact class; unknown classes are logged."""
        phase = self._EVENT_TYPES.get(event.__class__)

        if phase is None:
            logging.error(f"Unhandled event type: {event}")
            return

        if not phase:
            self.handle_maintenance_completed_event()
        else:
            self.handle_maintenance_start_event(MaintenanceState.MAINTENANCE)

    def _skip_timeout_handling(self) -> bool:
        """True while the connection is MOVING or relaxed timeouts are disabled."""
        moving = self.connection.maintenance_state == MaintenanceState.MOVING
        return moving or not self.config.is_relax_timeouts_enabled()

    def handle_maintenance_start_event(self, maintenance_state: MaintenanceState):
        """Enter ``maintenance_state`` and apply the relaxed timeout."""
        if self._skip_timeout_handling():
            return

        relaxed = self.config.relax_timeout
        self.connection.maintenance_state = maintenance_state
        self.connection.set_tmp_settings(tmp_relax_timeout=relaxed)
        # Extend the timeout of the current socket as well.
        self.connection.update_current_socket_timeout(relaxed)

    def handle_maintenance_completed_event(self):
        """Drop the relaxed timeout and return the connection to the NONE state."""
        if self._skip_timeout_handling():
            return

        self.connection.reset_tmp_settings(reset_relax_timeout=True)
        # -1 as the relax timeout restores the connection's regular timeout.
        self.connection.update_current_socket_timeout(-1)
        self.connection.maintenance_state = MaintenanceState.NONE
+ self.connection.maintenance_state = MaintenanceState.NONE