redis 6.4.0__py3-none-any.whl → 7.0.0b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,799 @@
1
+ import enum
2
+ import ipaddress
3
+ import logging
4
+ import re
5
+ import threading
6
+ import time
7
+ from abc import ABC, abstractmethod
8
+ from typing import TYPE_CHECKING, Literal, Optional, Union
9
+
10
+ from redis.typing import Number
11
+
12
+
13
class MaintenanceState(enum.Enum):
    """Maintenance phase tracked for connections and pool connection kwargs."""

    # No maintenance is in progress.
    NONE = "none"
    # A MOVING notification is being handled for the connection/pool.
    MOVING = "moving"
    # A "maintenance started" notification (e.g. migrating / failing over) is active.
    MAINTENANCE = "maintenance"
17
+
18
+
19
class EndpointType(enum.Enum):
    """
    Endpoint kinds accepted by the CLIENT MAINT_NOTIFICATIONS command.

    Converting a member with ``str()`` yields its raw wire value,
    e.g. ``str(EndpointType.INTERNAL_IP) == "internal-ip"``.
    """

    INTERNAL_IP = "internal-ip"
    INTERNAL_FQDN = "internal-fqdn"
    EXTERNAL_IP = "external-ip"
    EXTERNAL_FQDN = "external-fqdn"
    NONE = "none"

    def __str__(self) -> str:
        """Render the member as its plain string value."""
        return self._value_
31
+
32
+
33
+ if TYPE_CHECKING:
34
+ from redis.connection import (
35
+ BlockingConnectionPool,
36
+ ConnectionInterface,
37
+ ConnectionPool,
38
+ )
39
+
40
+
41
class MaintenanceNotification(ABC):
    """
    Abstract base for maintenance notifications pushed by the Redis server.

    Every notification carries an identifier and a time-to-live. The creation
    timestamp is taken from ``time.monotonic()`` when the object is built, so
    expiry checks are unaffected by wall-clock adjustments.

    Attributes:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
        creation_time (float): Monotonic timestamp taken at construction
        expire_at (float): ``creation_time + ttl`` computed at construction
    """

    def __init__(self, id: int, ttl: int):
        """
        Record the identifier and TTL and stamp the creation time.

        Args:
            id (int): Unique identifier for this notification
            ttl (int): Time-to-live in seconds for this notification
        """
        self.id = id
        self.ttl = ttl
        created = time.monotonic()
        self.creation_time = created
        self.expire_at = created + ttl

    def is_expired(self) -> bool:
        """
        Tell whether the TTL has elapsed since creation.

        Returns:
            bool: True once ``creation_time + ttl`` lies strictly in the past.
        """
        deadline = self.creation_time + self.ttl
        return deadline < time.monotonic()

    @abstractmethod
    def __repr__(self) -> str:
        """
        Describe the notification as a string.

        Concrete subclasses must provide this.

        Returns:
            str: String representation of the notification
        """
        ...

    @abstractmethod
    def __eq__(self, other) -> bool:
        """
        Decide whether two notifications are the same.

        Concrete subclasses must provide this; conventionally two notifications
        match when they share the same type and id.

        Args:
            other: The object to compare against

        Returns:
            bool: True when both notifications are considered equal
        """
        ...

    @abstractmethod
    def __hash__(self) -> int:
        """
        Hash the notification so it can be stored in sets or used as a dict key.

        Concrete subclasses must provide this, consistently with ``__eq__``.

        Returns:
            int: Hash value for the notification
        """
        ...
118
+
119
+
120
class NodeMovingNotification(MaintenanceNotification):
    """
    Pushed when a node is being replaced by another node, for example during
    cluster rebalancing or maintenance.

    The replacement endpoint may be absent, in which case ``new_node_host``
    and/or ``new_node_port`` are None.
    """

    def __init__(
        self,
        id: int,
        new_node_host: Optional[str],
        new_node_port: Optional[int],
        ttl: int,
    ):
        """
        Build a MOVING notification.

        Args:
            id (int): Unique identifier for this notification
            new_node_host (str): Hostname or IP address of the replacement node
            new_node_port (int): Port number of the replacement node
            ttl (int): Time-to-live in seconds for this notification
        """
        super().__init__(id, ttl)
        self.new_node_host = new_node_host
        self.new_node_port = new_node_port

    def __repr__(self) -> str:
        deadline = self.expire_at
        time_left = max(0, deadline - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"new_node_host='{self.new_node_host}'",
                f"new_node_port={self.new_node_port}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """
        Equal when ``other`` is also a NodeMovingNotification carrying the
        same id, new_node_host and new_node_port.
        """
        if isinstance(other, NodeMovingNotification):
            return (
                self.id == other.id
                and self.new_node_host == other.new_node_host
                and self.new_node_port == other.new_node_port
            )
        return False

    def __hash__(self) -> int:
        """
        Hash on class name, id, host and a normalized port.

        A falsy port hashes as None; a port that cannot be parsed as an int
        hashes as 0.

        Returns:
            int: Hash value usable for set/dict membership
        """
        try:
            port_key = int(self.new_node_port) if self.new_node_port else None
        except ValueError:
            port_key = 0

        identity = (
            self.__class__.__name__,
            int(self.id),
            str(self.new_node_host),
            port_key,
        )
        return hash(identity)
198
+
199
+
200
class NodeMigratingNotification(MaintenanceNotification):
    """
    Pushed when a cluster node starts migrating its slots to another node,
    for example during rebalancing or maintenance.

    Two instances are equal (and hash equally) when they have exactly the
    same type and the same id.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal only to a notification of exactly this type with the same id."""
        return (
            isinstance(other, NodeMigratingNotification)
            and self.id == other.id
            and type(self) is type(other)
        )

    def __hash__(self) -> int:
        """
        Hash on the class name and the integer id.

        Returns:
            int: Hash value usable for set/dict membership
        """
        key = (self.__class__.__name__, int(self.id))
        return hash(key)
247
+
248
+
249
class NodeMigratedNotification(MaintenanceNotification):
    """
    Pushed when a cluster node has finished migrating its slots to other
    nodes during rebalancing or maintenance.

    The TTL is fixed at ``DEFAULT_TTL`` seconds. Two instances are equal (and
    hash equally) when they have exactly the same type and the same id.

    Args:
        id (int): Unique identifier for this notification
    """

    # TTL (seconds) applied to every instance of this notification.
    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeMigratedNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal only to a notification of exactly this type with the same id."""
        return (
            isinstance(other, NodeMigratedNotification)
            and self.id == other.id
            and type(self) is type(other)
        )

    def __hash__(self) -> int:
        """
        Hash on the class name and the integer id.

        Returns:
            int: Hash value usable for set/dict membership
        """
        key = (self.__class__.__name__, int(self.id))
        return hash(key)
297
+
298
+
299
class NodeFailingOverNotification(MaintenanceNotification):
    """
    Pushed when a cluster node begins a failover, either for maintenance or
    in response to a node failure.

    Two instances are equal (and hash equally) when they have exactly the
    same type and the same id.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal only to a notification of exactly this type with the same id."""
        return (
            isinstance(other, NodeFailingOverNotification)
            and self.id == other.id
            and type(self) is type(other)
        )

    def __hash__(self) -> int:
        """
        Hash on the class name and the integer id.

        Returns:
            int: Hash value usable for set/dict membership
        """
        key = (self.__class__.__name__, int(self.id))
        return hash(key)
346
+
347
+
348
class NodeFailedOverNotification(MaintenanceNotification):
    """
    Pushed when a cluster node has completed a failover, either after
    maintenance or after handling a node failure.

    The TTL is fixed at ``DEFAULT_TTL`` seconds. Two instances are equal (and
    hash equally) when they have exactly the same type and the same id.

    Args:
        id (int): Unique identifier for this notification
    """

    # TTL (seconds) applied to every instance of this notification.
    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeFailedOverNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        deadline = self.creation_time + self.ttl
        time_left = max(0, deadline - time.monotonic())
        details = ", ".join(
            (
                f"id={self.id}",
                f"ttl={self.ttl}",
                f"creation_time={self.creation_time}",
                f"expires_at={deadline}",
                f"remaining={time_left:.1f}s",
                f"expired={self.is_expired()}",
            )
        )
        return f"{self.__class__.__name__}({details})"

    def __eq__(self, other) -> bool:
        """Equal only to a notification of exactly this type with the same id."""
        return (
            isinstance(other, NodeFailedOverNotification)
            and self.id == other.id
            and type(self) is type(other)
        )

    def __hash__(self) -> int:
        """
        Hash on the class name and the integer id.

        Returns:
            int: Hash value usable for set/dict membership
        """
        key = (self.__class__.__name__, int(self.id))
        return hash(key)
396
+
397
+
398
+ def _is_private_fqdn(host: str) -> bool:
399
+ """
400
+ Determine if an FQDN is likely to be internal/private.
401
+
402
+ This uses heuristics based on RFC 952 and RFC 1123 standards:
403
+ - .local domains (RFC 6762 - Multicast DNS)
404
+ - .internal domains (common internal convention)
405
+ - Single-label hostnames (no dots)
406
+ - Common internal TLDs
407
+
408
+ Args:
409
+ host (str): The FQDN to check
410
+
411
+ Returns:
412
+ bool: True if the FQDN appears to be internal/private
413
+ """
414
+ host_lower = host.lower().rstrip(".")
415
+
416
+ # Single-label hostnames (no dots) are typically internal
417
+ if "." not in host_lower:
418
+ return True
419
+
420
+ # Common internal/private domain patterns
421
+ internal_patterns = [
422
+ r"\.local$", # mDNS/Bonjour domains
423
+ r"\.internal$", # Common internal convention
424
+ r"\.corp$", # Corporate domains
425
+ r"\.lan$", # Local area network
426
+ r"\.intranet$", # Intranet domains
427
+ r"\.private$", # Private domains
428
+ ]
429
+
430
+ for pattern in internal_patterns:
431
+ if re.search(pattern, host_lower):
432
+ return True
433
+
434
+ # If none of the internal patterns match, assume it's external
435
+ return False
436
+
437
+
438
class MaintNotificationsConfig:
    """
    Configuration for how maintenance notifications are handled.

    Notifications arrive as push messages. This class controls whether they
    are enabled and how the client reacts to them (e.g. node moving,
    migrations) in a Redis cluster.
    """

    def __init__(
        self,
        enabled: Union[bool, Literal["auto"]] = "auto",
        proactive_reconnect: bool = True,
        relaxed_timeout: Optional[Number] = 10,
        endpoint_type: Optional[EndpointType] = None,
    ):
        """
        Initialize a new MaintNotificationsConfig.

        Args:
            enabled (bool | "auto"): Controls maintenance notifications handling behavior.
                - True: The CLIENT MAINT_NOTIFICATIONS command must succeed during connection setup,
                  otherwise a ResponseError is raised.
                - "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but failures are
                  gracefully handled - a warning is logged and normal operation continues.
                - False: Maintenance notifications are completely disabled.
                Defaults to "auto".
            proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
                Defaults to True.
            relaxed_timeout (Number): The relaxed timeout to use for the connection during maintenance.
                If -1 is provided, the relaxed timeout is disabled. Defaults to 10.
            endpoint_type (Optional[EndpointType]): Override for the endpoint type to use in
                CLIENT MAINT_NOTIFICATIONS. Either an EndpointType member or its string value
                (e.g. "internal-ip"); string values are converted to the matching member.
                If None, the endpoint type is determined automatically from the host and the
                connection's resolved IP. Defaults to None.

        Raises:
            ValueError: If endpoint_type is provided but is not a valid endpoint type.
        """
        # Validate (and normalize to the enum) as promised by the docstring;
        # previously invalid values were stored silently and surfaced only later.
        if endpoint_type is not None and not isinstance(endpoint_type, EndpointType):
            try:
                endpoint_type = EndpointType(endpoint_type)
            except ValueError:
                valid = ", ".join(member.value for member in EndpointType)
                raise ValueError(
                    f"Invalid endpoint_type {endpoint_type!r}; expected one of: {valid}"
                ) from None

        self.enabled = enabled
        self.relaxed_timeout = relaxed_timeout
        self.proactive_reconnect = proactive_reconnect
        self.endpoint_type = endpoint_type

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"enabled={self.enabled}, "
            f"proactive_reconnect={self.proactive_reconnect}, "
            f"relaxed_timeout={self.relaxed_timeout}, "
            f"endpoint_type={self.endpoint_type!r}"
            f")"
        )

    def is_relaxed_timeouts_enabled(self) -> bool:
        """
        Check whether the relaxed timeout is enabled.

        The value -1 disables it. Any other value, including None, counts as
        enabled. None is presumably applied as a blocking socket timeout that
        waits until a response arrives.

        Returns:
            True if the relaxed_timeout is enabled, False otherwise.
        """
        return self.relaxed_timeout != -1

    def get_endpoint_type(
        self, host: str, connection: "ConnectionInterface"
    ) -> EndpointType:
        """
        Determine the endpoint type to send with CLIENT MAINT_NOTIFICATIONS.

        Logic:
        1. If endpoint_type was set explicitly, use it.
        2. Otherwise inspect ``host``:
           - If it is an IP literal, classify it directly as internal-ip or external-ip.
           - If it is an FQDN, classify the connection's resolved IP as
             internal-fqdn or external-fqdn.
           - If no usable resolved IP is available, fall back to name-based
             heuristics on the FQDN itself.

        Args:
            host: User-provided hostname or IP address to analyze
            connection: Connection whose resolved IP is used when ``host`` is an FQDN

        Returns:
            EndpointType: The explicit override, or the inferred endpoint type.
        """
        # An explicit override always wins.
        if self.endpoint_type is not None:
            return self.endpoint_type

        # IP literal: classify it directly.
        try:
            ip_addr = ipaddress.ip_address(host)
        except ValueError:
            # Not an IP literal, so host is an FQDN; handled below.
            pass
        else:
            return (
                EndpointType.INTERNAL_IP
                if ip_addr.is_private
                else EndpointType.EXTERNAL_IP
            )

        # FQDN: classify by the address the connection actually resolved to.
        resolved_ip = connection.get_resolved_ip()
        if resolved_ip:
            try:
                resolved_addr = ipaddress.ip_address(resolved_ip)
            except ValueError:
                # Unexpected for a socket-reported address; fall back to heuristics.
                pass
            else:
                return (
                    EndpointType.INTERNAL_FQDN
                    if resolved_addr.is_private
                    else EndpointType.EXTERNAL_FQDN
                )

        # Final fallback: name-based heuristics on the FQDN itself.
        is_private = _is_private_fqdn(host)
        return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN
556
+
557
+
558
class MaintNotificationsPoolHandler:
    """
    Connection-pool-level handler for maintenance push notifications.

    Only NodeMovingNotification is acted upon. When one arrives, the handler:
    - switches pool connections matching the moving node's address to the
      MOVING state;
    - updates the kwargs used for new connections (host, and timeouts when
      relaxed timeouts are enabled);
    - optionally triggers a proactive reconnect;
    - schedules a timer that reverts these changes once the notification's
      TTL elapses.
    """

    def __init__(
        self,
        pool: Union["ConnectionPool", "BlockingConnectionPool"],
        config: MaintNotificationsConfig,
    ) -> None:
        self.pool = pool
        self.config = config
        # Notifications already handled. Expired entries are pruned on every
        # handle_notification() call.
        self._processed_notifications = set()
        # Re-entrant: handle_node_moving_notification() holds this lock while
        # calling run_proactive_reconnect(), which acquires it again.
        self._lock = threading.RLock()
        # Connection that received the notification (see set_connection()).
        # Its peer address selects which pool connections are affected.
        self.connection = None

    def set_connection(self, connection: "ConnectionInterface"):
        """Remember the connection whose notifications this handler processes."""
        self.connection = connection

    def remove_expired_notifications(self):
        """Drop processed notifications whose TTL has elapsed."""
        with self._lock:
            for notification in tuple(self._processed_notifications):
                if notification.is_expired():
                    self._processed_notifications.remove(notification)

    def handle_notification(self, notification: MaintenanceNotification):
        """Dispatch a notification; only NodeMovingNotification is handled at pool level."""
        self.remove_expired_notifications()

        if isinstance(notification, NodeMovingNotification):
            return self.handle_node_moving_notification(notification)
        logging.error("Unhandled notification type: %s", notification)

    @staticmethod
    def _start_daemon_timer(delay, callback, args) -> threading.Timer:
        """Start a one-shot timer whose thread does not block interpreter exit."""
        timer = threading.Timer(delay, callback, args=args)
        # Non-daemon timers would keep the process alive for up to `delay`
        # seconds after the program is otherwise finished.
        timer.daemon = True
        timer.start()
        return timer

    def handle_node_moving_notification(self, notification: NodeMovingNotification):
        """
        Apply a MOVING notification to the pool.

        Does nothing when neither proactive reconnect nor relaxed timeouts are
        enabled, or when the notification was already processed. Otherwise the
        pool's matching connections and new-connection kwargs are updated. A
        timer then calls handle_node_moved_notification() after ``ttl``
        seconds to revert them.
        """
        if (
            not self.config.proactive_reconnect
            and not self.config.is_relaxed_timeouts_enabled()
        ):
            return
        with self._lock:
            if notification in self._processed_notifications:
                # Already handled (and not yet expired) - nothing to do.
                return

            with self.pool._lock:
                # Address of the node that is being moved. Only connections
                # connected to this same address are updated.
                moving_address_src = (
                    self.connection.getpeername() if self.connection else None
                )

                # Maintenance mode exists only on some pool types
                # (BlockingConnectionPool).
                supports_maintenance_mode = getattr(
                    self.pool, "set_in_maintenance", False
                )
                if supports_maintenance_mode:
                    self.pool.set_in_maintenance(True)
                try:
                    # Update maintenance state, timeout and optionally host
                    # address for matching connections (free ones included).
                    self.pool.update_connections_settings(
                        state=MaintenanceState.MOVING,
                        maintenance_notification_hash=hash(notification),
                        relaxed_timeout=self.config.relaxed_timeout,
                        host_address=notification.new_node_host,
                        matching_address=moving_address_src,
                        matching_pattern="connected_address",
                        update_notification_hash=True,
                        include_free_connections=True,
                    )

                    if self.config.proactive_reconnect:
                        if notification.new_node_host is not None:
                            self.run_proactive_reconnect(moving_address_src)
                        else:
                            # No new host announced: defer the reconnect
                            # to half of the notification's TTL.
                            self._start_daemon_timer(
                                notification.ttl / 2,
                                self.run_proactive_reconnect,
                                (moving_address_src,),
                            )

                    # Kwargs for connections created from now on: MOVING
                    # state, the new host if known, and relaxed timeouts if
                    # enabled.
                    kwargs: dict = {
                        "maintenance_state": MaintenanceState.MOVING,
                        "maintenance_notification_hash": hash(notification),
                    }
                    if notification.new_node_host is not None:
                        # When the MOVING push omits the new host, only the
                        # timeouts are updated.
                        kwargs["host"] = notification.new_node_host
                    if self.config.is_relaxed_timeouts_enabled():
                        kwargs["socket_timeout"] = self.config.relaxed_timeout
                        kwargs["socket_connect_timeout"] = self.config.relaxed_timeout
                    self.pool.update_connection_kwargs(**kwargs)
                finally:
                    # Always leave maintenance mode, even if an update above
                    # raised; otherwise the pool would stay in maintenance mode.
                    if supports_maintenance_mode:
                        self.pool.set_in_maintenance(False)

                # Revert the changes once the notification's TTL elapses.
                self._start_daemon_timer(
                    notification.ttl,
                    self.handle_node_moved_notification,
                    (notification,),
                )

            self._processed_notifications.add(notification)

    def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
        """
        Run proactive reconnect for the pool.

        Active connections are marked for reconnect after they complete the
        current command. Inactive connections are disconnected and will be
        connected on next use.
        """
        with self._lock:
            with self.pool._lock:
                self.pool.update_active_connections_for_reconnect(
                    moving_address_src=moving_address_src,
                )
                self.pool.disconnect_free_connections(
                    moving_address_src=moving_address_src,
                )

    def handle_node_moved_notification(self, notification: NodeMovingNotification):
        """
        Clean up after a node moving notification expires.

        New-connection kwargs are restored to their ``orig_*`` values only if
        they still carry this notification's hash. A different hash means a
        newer MOVING notification has since taken over. Connections tagged
        with this notification's hash are reset to the NONE state in either
        case.
        """
        notification_hash = hash(notification)

        with self._lock:
            if (
                self.pool.connection_kwargs.get("maintenance_notification_hash")
                == notification_hash
            ):
                orig_host = self.pool.connection_kwargs.get("orig_host_address")
                orig_socket_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_timeout"
                )
                orig_connect_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_connect_timeout"
                )
                kwargs: dict = {
                    "maintenance_state": MaintenanceState.NONE,
                    "maintenance_notification_hash": None,
                    "host": orig_host,
                    "socket_timeout": orig_socket_timeout,
                    "socket_connect_timeout": orig_connect_timeout,
                }
                self.pool.update_connection_kwargs(**kwargs)

            with self.pool._lock:
                reset_relaxed_timeout = self.config.is_relaxed_timeouts_enabled()
                reset_host_address = self.config.proactive_reconnect

                self.pool.update_connections_settings(
                    relaxed_timeout=-1,
                    state=MaintenanceState.NONE,
                    maintenance_notification_hash=None,
                    matching_notification_hash=notification_hash,
                    matching_pattern="notification_hash",
                    update_notification_hash=True,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                    reset_host_address=reset_host_address,
                    include_free_connections=True,
                )
742
+
743
+
744
class MaintNotificationsConnectionHandler:
    """
    Per-connection handler for "maintenance started" and "maintenance completed"
    notifications (migrating/migrated, failing over/failed over).

    These notifications only relax the connection's timeouts or restore them.
    They are ignored while the connection is in the MOVING state or when
    relaxed timeouts are disabled.
    """

    # 1 = "starting maintenance" notifications, 0 = "completed maintenance" notifications
    _NOTIFICATION_TYPES: dict[type["MaintenanceNotification"], int] = {
        NodeMigratingNotification: 1,
        NodeFailingOverNotification: 1,
        NodeMigratedNotification: 0,
        NodeFailedOverNotification: 0,
    }

    def __init__(
        self, connection: "ConnectionInterface", config: MaintNotificationsConfig
    ) -> None:
        self.connection = connection
        self.config = config

    def handle_notification(self, notification: MaintenanceNotification):
        """Route a notification to the start/completed handler by its exact class."""
        phase = self._NOTIFICATION_TYPES.get(notification.__class__)

        if phase is None:
            logging.error(f"Unhandled notification type: {notification}")
            return

        if not phase:
            self.handle_maintenance_completed_notification()
        else:
            self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE)

    def _relaxed_timeouts_not_applicable(self) -> bool:
        """True while MOVING is in effect or relaxed timeouts are disabled."""
        if self.connection.maintenance_state == MaintenanceState.MOVING:
            return True
        return not self.config.is_relaxed_timeouts_enabled()

    def handle_maintenance_start_notification(
        self, maintenance_state: MaintenanceState
    ):
        """Enter ``maintenance_state`` and switch the connection to the relaxed timeout."""
        if self._relaxed_timeouts_not_applicable():
            return

        relaxed = self.config.relaxed_timeout
        self.connection.maintenance_state = maintenance_state
        self.connection.set_tmp_settings(tmp_relaxed_timeout=relaxed)
        # Also apply the relaxed timeout to the socket that is already open.
        self.connection.update_current_socket_timeout(relaxed)

    def handle_maintenance_completed_notification(self):
        """Restore the connection's normal timeouts and clear its maintenance state."""
        if self._relaxed_timeouts_not_applicable():
            return

        self.connection.reset_tmp_settings(reset_relaxed_timeout=True)
        # -1 as the relaxed timeout signals "restore the regular timeout".
        self.connection.update_current_socket_timeout(-1)
        self.connection.maintenance_state = MaintenanceState.NONE