coredis 5.5.0__cp314-cp314-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100) hide show
  1. 22fe76227e35f92ab5c3__mypyc.cpython-314-darwin.so +0 -0
  2. coredis/__init__.py +42 -0
  3. coredis/_enum.py +42 -0
  4. coredis/_json.py +11 -0
  5. coredis/_packer.cpython-314-darwin.so +0 -0
  6. coredis/_packer.py +71 -0
  7. coredis/_protocols.py +50 -0
  8. coredis/_py_311_typing.py +20 -0
  9. coredis/_py_312_typing.py +17 -0
  10. coredis/_sidecar.py +114 -0
  11. coredis/_utils.cpython-314-darwin.so +0 -0
  12. coredis/_utils.py +440 -0
  13. coredis/_version.py +34 -0
  14. coredis/_version.pyi +1 -0
  15. coredis/cache.py +801 -0
  16. coredis/client/__init__.py +6 -0
  17. coredis/client/basic.py +1240 -0
  18. coredis/client/cluster.py +1265 -0
  19. coredis/commands/__init__.py +64 -0
  20. coredis/commands/_key_spec.py +517 -0
  21. coredis/commands/_utils.py +108 -0
  22. coredis/commands/_validators.py +159 -0
  23. coredis/commands/_wrappers.py +175 -0
  24. coredis/commands/bitfield.py +110 -0
  25. coredis/commands/constants.py +662 -0
  26. coredis/commands/core.py +8484 -0
  27. coredis/commands/function.py +408 -0
  28. coredis/commands/monitor.py +168 -0
  29. coredis/commands/pubsub.py +905 -0
  30. coredis/commands/request.py +108 -0
  31. coredis/commands/script.py +296 -0
  32. coredis/commands/sentinel.py +246 -0
  33. coredis/config.py +50 -0
  34. coredis/connection.py +906 -0
  35. coredis/constants.cpython-314-darwin.so +0 -0
  36. coredis/constants.py +37 -0
  37. coredis/credentials.py +45 -0
  38. coredis/exceptions.py +360 -0
  39. coredis/experimental/__init__.py +1 -0
  40. coredis/globals.py +23 -0
  41. coredis/modules/__init__.py +121 -0
  42. coredis/modules/autocomplete.py +138 -0
  43. coredis/modules/base.py +262 -0
  44. coredis/modules/filters.py +1319 -0
  45. coredis/modules/graph.py +362 -0
  46. coredis/modules/json.py +691 -0
  47. coredis/modules/response/__init__.py +0 -0
  48. coredis/modules/response/_callbacks/__init__.py +0 -0
  49. coredis/modules/response/_callbacks/autocomplete.py +42 -0
  50. coredis/modules/response/_callbacks/graph.py +237 -0
  51. coredis/modules/response/_callbacks/json.py +21 -0
  52. coredis/modules/response/_callbacks/search.py +221 -0
  53. coredis/modules/response/_callbacks/timeseries.py +158 -0
  54. coredis/modules/response/types.py +179 -0
  55. coredis/modules/search.py +1089 -0
  56. coredis/modules/timeseries.py +1139 -0
  57. coredis/parser.cpython-314-darwin.so +0 -0
  58. coredis/parser.py +344 -0
  59. coredis/pipeline.py +1225 -0
  60. coredis/pool/__init__.py +11 -0
  61. coredis/pool/basic.py +453 -0
  62. coredis/pool/cluster.py +517 -0
  63. coredis/pool/nodemanager.py +340 -0
  64. coredis/py.typed +0 -0
  65. coredis/recipes/__init__.py +0 -0
  66. coredis/recipes/credentials/__init__.py +5 -0
  67. coredis/recipes/credentials/iam_provider.py +63 -0
  68. coredis/recipes/locks/__init__.py +5 -0
  69. coredis/recipes/locks/extend.lua +17 -0
  70. coredis/recipes/locks/lua_lock.py +281 -0
  71. coredis/recipes/locks/release.lua +10 -0
  72. coredis/response/__init__.py +5 -0
  73. coredis/response/_callbacks/__init__.py +538 -0
  74. coredis/response/_callbacks/acl.py +32 -0
  75. coredis/response/_callbacks/cluster.py +183 -0
  76. coredis/response/_callbacks/command.py +86 -0
  77. coredis/response/_callbacks/connection.py +31 -0
  78. coredis/response/_callbacks/geo.py +58 -0
  79. coredis/response/_callbacks/hash.py +85 -0
  80. coredis/response/_callbacks/keys.py +59 -0
  81. coredis/response/_callbacks/module.py +33 -0
  82. coredis/response/_callbacks/script.py +85 -0
  83. coredis/response/_callbacks/sentinel.py +179 -0
  84. coredis/response/_callbacks/server.py +241 -0
  85. coredis/response/_callbacks/sets.py +44 -0
  86. coredis/response/_callbacks/sorted_set.py +204 -0
  87. coredis/response/_callbacks/streams.py +185 -0
  88. coredis/response/_callbacks/strings.py +70 -0
  89. coredis/response/_callbacks/vector_sets.py +159 -0
  90. coredis/response/_utils.py +33 -0
  91. coredis/response/types.py +416 -0
  92. coredis/retry.py +233 -0
  93. coredis/sentinel.py +477 -0
  94. coredis/stream.py +369 -0
  95. coredis/tokens.py +2286 -0
  96. coredis/typing.py +593 -0
  97. coredis-5.5.0.dist-info/METADATA +211 -0
  98. coredis-5.5.0.dist-info/RECORD +100 -0
  99. coredis-5.5.0.dist-info/WHEEL +6 -0
  100. coredis-5.5.0.dist-info/licenses/LICENSE +23 -0
coredis/connection.py ADDED
@@ -0,0 +1,906 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import dataclasses
5
+ import functools
6
+ import inspect
7
+ import itertools
8
+ import os
9
+ import socket
10
+ import ssl
11
+ import time
12
+ import warnings
13
+ import weakref
14
+ from collections import defaultdict, deque
15
+ from contextlib import suppress
16
+ from typing import TYPE_CHECKING, Any, cast
17
+
18
+ import async_timeout
19
+
20
+ import coredis
21
+ from coredis._packer import Packer
22
+ from coredis._utils import nativestr
23
+ from coredis.credentials import (
24
+ AbstractCredentialProvider,
25
+ UserPass,
26
+ UserPassCredentialProvider,
27
+ )
28
+ from coredis.exceptions import (
29
+ AuthenticationRequiredError,
30
+ ConnectionError,
31
+ RedisError,
32
+ TimeoutError,
33
+ UnknownCommandError,
34
+ )
35
+ from coredis.parser import NotEnoughData, Parser
36
+ from coredis.tokens import PureToken
37
+ from coredis.typing import (
38
+ Awaitable,
39
+ Callable,
40
+ ClassVar,
41
+ Literal,
42
+ RedisValueT,
43
+ ResponseType,
44
+ TypeVar,
45
+ )
46
+
47
# Generic type variable. NOTE(review): not referenced anywhere else in this
# module; presumably kept for importers — confirm before removing.
R = TypeVar("R")

if TYPE_CHECKING:
    # Only needed for the ``ClusterConnection.node`` annotation; guarded so the
    # import does not execute at runtime (presumably to avoid an import cycle).
    from coredis.pool.nodemanager import ManagedNode


@dataclasses.dataclass
class Request:
    """
    A single command sent on a connection that is awaiting its response.

    The owning connection queues requests in send order and resolves
    :attr:`future` with the matching response (or an exception).
    """

    #: weak proxy to the connection the command was sent on
    connection: weakref.ProxyType[Connection]
    #: name of the command (used in error messages)
    command: bytes
    #: whether the response should be decoded
    decode: bool
    #: encoding used when decoding the response
    encoding: str | None = None
    #: when ``True`` error replies are set as exceptions on :attr:`future`
    raise_exceptions: bool = True
    #: future resolved with the response (created on the running event loop)
    future: asyncio.Future[ResponseType] = dataclasses.field(
        default_factory=lambda: asyncio.get_running_loop().create_future()
    )
    #: wall clock time at which the request was created
    created_at: float = dataclasses.field(default_factory=lambda: time.time())

    def __post_init__(self) -> None:
        self.future.add_done_callback(self.cleanup)

    def cleanup(self, future: asyncio.Future[ResponseType]) -> None:
        """
        Done-callback for :attr:`future`.

        If the future was cancelled, the owning connection (when it still
        exists and is connected) is disconnected — presumably so that the
        orphaned reply can not be matched against a later request.

        :param future: the completed future (the same object as :attr:`future`)
        """
        if not future.cancelled():
            return
        try:
            if self.connection and self.connection.is_connected:
                self.connection.disconnect()
        except ReferenceError:
            # ``connection`` is a weak proxy; the connection object has already
            # been garbage collected, so there is nothing left to disconnect.
            pass

    def enforce_deadline(self, timeout: float) -> None:
        """
        Fail the request with a timeout error if it has not completed yet.

        :param timeout: the deadline (in seconds) that elapsed, used in the
         error message
        """
        if not self.future.done():
            self.future.set_exception(
                TimeoutError(f"command {nativestr(self.command)} timed out after {timeout} seconds")
            )


@dataclasses.dataclass()
class CommandInvocation:
    """
    A command together with its arguments, used to send several commands
    in one write via ``create_requests``.
    """

    #: the command name
    command: bytes
    #: positional arguments for the command
    args: tuple[RedisValueT, ...]
    #: per command decode override (``None`` defers to the connection setting)
    decode: bool | None
    #: per command encoding override (falsy values defer to the connection setting)
    encoding: str | None


class RedisSSLContext:
    """
    Lazily constructed :class:`ssl.SSLContext` for TLS connections to a
    redis server.

    :param keyfile: path to the client private key. May be omitted when
     ``certfile`` also contains the private key.
    :param certfile: path to the client certificate (chain)
    :param cert_reqs: server certificate requirement; either an
     :class:`ssl.VerifyMode` or one of ``"none"``, ``"optional"``,
     ``"required"`` (case-insensitive). Defaults to ``ssl.CERT_OPTIONAL``.
    :param ca_certs: a CA bundle file or a directory of CA certificates
    :param check_hostname: whether the server hostname is verified
     (defaults to ``False``)
    :raises KeyError: if ``cert_reqs`` is an unrecognized string
    """

    context: ssl.SSLContext | None

    def __init__(
        self,
        keyfile: str | None,
        certfile: str | None,
        cert_reqs: str | ssl.VerifyMode | None = None,
        ca_certs: str | None = None,
        check_hostname: bool | None = None,
    ) -> None:
        self.keyfile = keyfile
        self.certfile = certfile
        self.check_hostname = check_hostname if check_hostname is not None else False
        if cert_reqs is None:
            self.cert_reqs = ssl.CERT_OPTIONAL
        elif isinstance(cert_reqs, str):
            CERT_REQS = {
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            # Accept any casing (e.g. "REQUIRED"); unknown values still raise KeyError.
            self.cert_reqs = CERT_REQS[cert_reqs.lower()]
        else:
            self.cert_reqs = cert_reqs
        self.ca_certs = ca_certs
        self.context = None

    def get(self) -> ssl.SSLContext:
        """
        Build the SSL context on first use and return the cached instance
        afterwards.

        The context is only cached once it has been fully configured, so a
        failure (e.g. an unreadable certificate file) is raised again on the
        next call instead of a partially configured context being reused.

        :raises OSError: (including :class:`ssl.SSLError`) if the certificate,
         key or CA files can not be loaded
        """
        if not self.context:
            context = ssl.create_default_context()
            if self.certfile:
                # keyfile is optional: load_cert_chain reads the private key
                # from certfile when keyfile is None.
                context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
            if self.ca_certs:
                context.load_verify_locations(
                    **{("capath" if os.path.isdir(self.ca_certs) else "cafile"): self.ca_certs}
                )
            # check_hostname is assigned first: verify_mode can not be lowered
            # to CERT_NONE while hostname checking is still enabled.
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            self.context = context
        return self.context


class BaseConnection(asyncio.BaseProtocol):
    """
    Base connection class which implements
    :class:`asyncio.BaseProtocol` to interact
    with the underlying connection established
    with the redis server.

    Commands are pipelined: every command written to the transport that
    expects a reply is paired with a :class:`Request` appended to a FIFO
    queue, and incoming responses are matched to those requests in order.
    """

    #: id for this connection as returned by the redis server
    client_id: int | None
    #: Queue that collects any unread push message types
    push_messages: asyncio.Queue[ResponseType]
    #: client id that the redis server should send any redirected notifications to
    tracking_client_id: int | None
    #: Whether the connection should use RESP or RESP3
    protocol_version: Literal[2, 3]

    description: ClassVar[str] = "BaseConnection"
    locator: ClassVar[str] = ""

    #: average response time of requests made on this connection
    average_response_time: float

    def __init__(
        self,
        stream_timeout: float | None = None,
        encoding: str = "utf-8",
        decode_responses: bool = False,
        *,
        client_name: str | None = None,
        protocol_version: Literal[2, 3] = 3,
        noreply: bool = False,
        noevict: bool = False,
        notouch: bool = False,
    ):
        self._stream_timeout = stream_timeout
        self.username: str | None = None
        self.password: str | None = ""
        self.credential_provider: AbstractCredentialProvider | None = None
        self.db: int | None = None
        self.pid: int = os.getpid()
        self._description_args: Callable[..., dict[str, str | int | None]] = lambda: dict()
        self._connect_callbacks: list[
            (Callable[[BaseConnection], Awaitable[None]] | Callable[[BaseConnection], None])
        ] = list()
        self.encoding = encoding
        self.decode_responses = decode_responses
        self.protocol_version = protocol_version
        self.server_version: str | None = None
        self.client_name = client_name
        self.client_id = None
        self.tracking_client_id = None

        self.last_active_at: float = time.time()
        self.last_request_processed_at: float | None = None

        self._transport: asyncio.Transport | None = None
        self._parser = Parser()
        self._read_flag = asyncio.Event()
        self._read_waiters: set[asyncio.Task[bool]] = set()
        self.packer: Packer = Packer(self.encoding)
        self.push_messages: asyncio.Queue[ResponseType] = asyncio.Queue()

        self.noreply: bool = noreply
        self.noreply_set: bool = False

        self.noevict: bool = noevict
        self.notouch: bool = notouch

        self.needs_handshake: bool = True
        self._last_error: BaseException | None = None
        self._connection_error: BaseException | None = None

        self._requests: deque[Request] = deque()

        self.average_response_time: float = 0
        self.requests_processed: int = 0
        self._write_ready: asyncio.Event = asyncio.Event()
        self._transport_lock: asyncio.Lock = asyncio.Lock()

    def __repr__(self) -> str:
        return self.describe(self._description_args())

    @classmethod
    def describe(cls, description_args: dict[str, Any]) -> str:
        """
        Render :attr:`description` with ``description_args``; missing keys
        are rendered as ``None``.
        """
        return cls.description.format_map(defaultdict(lambda: None, description_args))

    @property
    def location(self) -> str:
        """
        :attr:`locator` rendered with this connection's description args
        """
        return self.locator.format_map(defaultdict(lambda: None, self._description_args()))

    @property
    def estimated_time_to_idle(self) -> float:
        """
        Estimated time till the pending request queue of this connection
        has been cleared
        """
        return self.requests_pending * self.average_response_time

    def __del__(self) -> None:
        try:
            self.disconnect()
        except Exception:  # noqa
            pass

    @property
    def is_connected(self) -> bool:
        """
        Whether the connection is established and initial handshakes were
        performed without error
        """
        return self._transport is not None and self._connection_error is None

    @property
    def requests_pending(self) -> int:
        """
        Number of requests pending response on this connection
        """
        return len(self._requests)

    @property
    def lag(self) -> float:
        """
        Returns the amount of seconds since the last request was processed
        if there are still in flight requests pending on this connection
        """
        if not self._requests:
            return 0
        elif self.last_request_processed_at is None:
            # NOTE(review): with nothing processed yet this returns the raw
            # epoch timestamp, effectively marking the connection as maximally
            # lagged — confirm that this sentinel is intended.
            return time.time()
        else:
            return time.time() - self.last_request_processed_at

    def register_connect_callback(
        self,
        callback: (Callable[[BaseConnection], None] | Callable[[BaseConnection], Awaitable[None]]),
    ) -> None:
        """
        Register a (sync or async) callback that is invoked with this
        connection after every successful :meth:`connect`.
        """
        self._connect_callbacks.append(callback)

    def clear_connect_callbacks(self) -> None:
        """Remove all registered connect callbacks."""
        self._connect_callbacks = list()

    async def can_read(self) -> bool:
        """Checks for data that can be read"""
        assert self._parser

        if not self.is_connected:
            await self.connect()

        return self._parser.can_read()

    async def connect(self) -> None:
        """
        Establish a connection to the redis server
        and initiate any post connect callbacks

        Errors raised while connecting are recorded (making
        :attr:`is_connected` ``False``). Anything that is neither a
        :exc:`~coredis.exceptions.RedisError` nor a cancellation is re-raised
        as :exc:`~coredis.exceptions.ConnectionError`.
        """
        self._connection_error = None
        try:
            await self._connect()
        except (asyncio.CancelledError, RedisError) as err:
            self._connection_error = err
            raise
        except Exception as err:
            self._connection_error = err
            raise ConnectionError(str(err)) from err

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        for callback in self._connect_callbacks:
            task = callback(self)
            if inspect.isawaitable(task):
                await task

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """
        :meta private:
        """
        self._transport = cast(asyncio.Transport, transport)
        self._write_ready.set()

    def connection_lost(self, exc: BaseException | None) -> None:
        """
        :meta private:
        """
        if exc:
            self._last_error = exc

        self.disconnect()

    def pause_writing(self) -> None:
        """
        :meta private:
        """
        self._write_ready.clear()

    def resume_writing(self) -> None:
        """
        :meta private:
        """
        self._write_ready.set()

    def data_received(self, data: bytes) -> None:
        """
        Feed received bytes to the parser and resolve as many pending
        requests (in FIFO order) as there are complete responses.

        :meta private:
        """
        self._parser.feed(data)
        self._read_flag.set()
        if not self._requests:
            return

        request = self._requests.popleft()
        response = self._parser.get_response(request.decode, request.encoding)
        while not isinstance(response, NotEnoughData):
            # The response is always consumed so that ordering is preserved,
            # even if the request already timed out or was cancelled.
            if not (request.future.cancelled() or request.future.done()):
                if request.raise_exceptions and isinstance(response, RedisError):
                    request.future.set_exception(response)
                else:
                    request.future.set_result(response)

            self.last_request_processed_at = time.time()
            self.requests_processed += 1
            response_time = time.time() - request.created_at

            # incremental running mean over all processed requests
            self.average_response_time = (
                (self.average_response_time * (self.requests_processed - 1)) + response_time
            ) / self.requests_processed

            try:
                request = self._requests.popleft()
            except IndexError:
                return

            response = self._parser.get_response(request.decode, request.encoding)

        # In case the first request pulled from the queue doesn't have enough data
        # to process, put it back to the start of the queue for the next iteration
        if request:
            self._requests.appendleft(request)

    def eof_received(self) -> None:
        """
        :meta private:
        """
        self.disconnect()

    async def _connect(self) -> None:
        raise NotImplementedError

    async def update_tracking_client(self, enabled: bool, client_id: int | None = None) -> bool:
        """
        Associate this connection to :paramref:`client_id` to
        relay any tracking notifications to.

        :return: ``True`` if tracking was toggled, ``False`` if it failed for
         any reason other than the command being unknown
        :raises UnknownCommandError: if the server does not support
         ``CLIENT TRACKING``
        """
        try:
            params: list[RedisValueT] = (
                [b"ON", b"REDIRECT", client_id] if (enabled and client_id is not None) else [b"OFF"]
            )

            if (
                await (await self.create_request(b"CLIENT TRACKING", *params, decode=False))
                != b"OK"
            ):
                raise ConnectionError("Unable to toggle client tracking")
            self.tracking_client_id = client_id
            return True
        except UnknownCommandError:  # noqa
            raise
        except Exception:  # noqa
            return False

    async def try_legacy_auth(self) -> None:
        """
        Authenticate with the ``AUTH`` command (used when a ``HELLO`` based
        handshake is not possible).

        Credentials come from :attr:`credential_provider` when set, otherwise
        from :attr:`username` / :attr:`password`. Nothing is sent if neither a
        provider nor a password is configured.
        """
        if self.credential_provider:
            creds = await self.credential_provider.get_credentials()
            params = [creds.password]
            if isinstance(creds, UserPass):
                params.insert(0, creds.username)
        elif not self.password:
            return
        else:
            params = [self.password]
            if self.username:
                params.insert(0, self.username)
        await (await self.create_request(b"AUTH", *params, decode=False))

    @staticmethod
    def _server_version_at_least(version: str | None, minimum: tuple[int, ...]) -> bool:
        """
        Compare a dotted server version (e.g. ``"7.2.4"``) numerically against
        ``minimum``. Parsing stops at the first component that is not purely
        numeric. A plain string comparison would wrongly treat ``"10.0.0"``
        as older than ``"7.2"``.
        """
        if not version:
            return False
        components: list[int] = []
        for part in version.split("."):
            if not part.isdecimal():
                break
            components.append(int(part))
        return tuple(components) >= minimum

    async def perform_handshake(self) -> None:
        """
        Perform the ``HELLO`` handshake (authenticating in the same command
        when credentials are available) and record the server version and
        client id. On servers reporting 7.2 or newer the library name and
        version are also sent with ``CLIENT SETINFO``.

        Falls back to ``AUTH`` if the server requires authentication, or for
        RESP2 connections to servers without ``HELLO``.

        :raises ConnectionError: if RESP3 was requested but the server does
         not implement ``HELLO``
        """
        if not self.needs_handshake:
            return

        hello_command_args: list[int | str | bytes] = [self.protocol_version]
        if creds := (
            await self.credential_provider.get_credentials()
            if self.credential_provider
            else (
                await UserPassCredentialProvider(self.username, self.password).get_credentials()
                if (self.username or self.password)
                else None
            )
        ):
            hello_command_args.extend(
                [
                    "AUTH",
                    creds.username,
                    creds.password or b"",
                ]
            )
        try:
            hello_resp = await (
                await self.create_request(b"HELLO", *hello_command_args, decode=False)
            )
            assert isinstance(hello_resp, (list, dict))
            if self.protocol_version == 3:
                resp3 = cast(dict[bytes, RedisValueT], hello_resp)
                assert resp3[b"proto"] == 3
                self.server_version = nativestr(resp3[b"version"])
                self.client_id = int(resp3[b"id"])
            else:
                resp = cast(list[RedisValueT], hello_resp)
                self.server_version = nativestr(resp[3])
                self.client_id = int(resp[7])
            if self._server_version_at_least(self.server_version, (7, 2)):
                await asyncio.gather(
                    await self.create_request(
                        b"CLIENT SETINFO",
                        b"LIB-NAME",
                        b"coredis",
                    ),
                    await self.create_request(
                        b"CLIENT SETINFO",
                        b"LIB-VER",
                        coredis.__version__,
                    ),
                )
            self.needs_handshake = False
        except AuthenticationRequiredError:
            await self.try_legacy_auth()
            self.server_version = None
            self.client_id = None
        except UnknownCommandError:  # noqa
            # This should only happen for redis servers < 6 or forks of redis
            # that are not > 6 compliant.
            warning = (
                "The server responded with no support for the `HELLO` command"
                " and therefore a handshake could not be performed"
            )
            if self.protocol_version == 3:
                raise ConnectionError(
                    "Unable to use RESP3 due to missing `HELLO` implementation "
                    "the server. Use `protocol_version=2` when constructing the client."
                )
            else:
                warnings.warn(warning, category=UserWarning)
            await self.try_legacy_auth()
            self.needs_handshake = False

    async def on_connect(self) -> None:
        """
        Post connect initialization: attach the parser, perform the handshake,
        then select the database and apply client name, ``NO-EVICT``,
        ``NO-TOUCH`` and ``REPLY OFF`` settings as configured.

        :raises ConnectionError: if selecting the database or setting the
         client name is not acknowledged with ``OK``
        """
        self._parser.on_connect(self)
        await self.perform_handshake()

        if self.db:
            if await (await self.create_request(b"SELECT", self.db, decode=False)) != b"OK":
                raise ConnectionError(f"Invalid Database {self.db}")

        if self.client_name is not None:
            if (
                await (await self.create_request(b"CLIENT SETNAME", self.client_name, decode=False))
                != b"OK"
            ):
                raise ConnectionError(f"Failed to set client name: {self.client_name}")

        if self.noevict:
            await (await self.create_request(b"CLIENT NO-EVICT", b"ON"))

        if self.notouch:
            await (await self.create_request(b"CLIENT NO-TOUCH", b"ON"))

        if self.noreply:
            await (await self.create_request(b"CLIENT REPLY", b"OFF", noreply=True))
            self.noreply_set = True

        self.last_active_at = time.time()

    async def fetch_push_message(
        self,
        decode: RedisValueT | None = None,
        push_message_types: set[bytes] | None = None,
        block: bool | None = False,
    ) -> ResponseType:
        """
        Read the next pending response

        :param decode: decode override (defaults to the connection setting)
        :param push_message_types: push message types to accept
        :param block: wait indefinitely instead of at most the stream timeout
        :raises ConnectionError: if requests are still pending, or if the
         connection is lost while waiting
        :raises TimeoutError: if no message arrives within the stream timeout
        """
        if not self.is_connected:
            await self.connect()

        if len(self._requests) > 0:
            raise ConnectionError(
                f"Invalid request for push messages. {len(self._requests)} requests still pending"
            )

        decode_response = bool(decode) if decode is not None else self.decode_responses
        timeout = self._stream_timeout if not block else None

        message = self._parser.get_response(decode_response, self.encoding, push_message_types)
        while isinstance(message, NotEnoughData):
            self._read_flag.clear()
            try:
                read_ready_task = asyncio.create_task(self._read_flag.wait())
                # The bound method receives the finished task itself, avoiding a
                # late-binding closure over the loop variable.
                read_ready_task.add_done_callback(self._read_waiters.discard)
                self._read_waiters.add(read_ready_task)
                await asyncio.wait_for(read_ready_task, timeout)
            except asyncio.TimeoutError:
                raise TimeoutError
            except asyncio.CancelledError:
                # disconnect() cancels read waiters; surface that as a lost connection
                if not self.is_connected:
                    raise ConnectionError("Connection lost")
                raise
            message = self._parser.get_response(decode_response, self.encoding, push_message_types)
        return message

    async def _send_packed_command(
        self, command: list[bytes], timeout: float | None = None
    ) -> None:
        """
        Sends an already packed command to the Redis server

        Waits (up to ``timeout`` seconds) for the transport to accept writes.

        :raises TimeoutError: if the transport does not become writable in time
         (the connection is disconnected in that case)
        :raises ConnectionError: if the connection was lost while waiting
        """

        assert self._transport
        try:
            async with async_timeout.timeout(timeout):
                await self._write_ready.wait()
        except asyncio.TimeoutError:
            if self._transport:
                self.disconnect()
            raise TimeoutError(f"Unable to write after waiting for socket for {timeout} seconds")
        if self._transport is None:
            # disconnect() ran while we were waiting for writing to resume
            raise ConnectionError("Connection lost while waiting to write")
        self._transport.writelines(command)

    async def send_command(
        self,
        command: bytes,
        *args: RedisValueT,
    ) -> None:
        """
        Send a command to the redis server
        """

        if not self.is_connected:
            await self.connect()

        await self._send_packed_command(self.packer.pack_command(command, *args))

        self.last_active_at = time.time()

    @staticmethod
    def _completed_future() -> asyncio.Future[ResponseType]:
        """Return a future already resolved with ``None`` (no reply expected)."""
        future: asyncio.Future[ResponseType] = asyncio.get_running_loop().create_future()
        future.set_result(None)
        return future

    def _schedule_deadline(self, request: Request, timeout: float | None) -> None:
        """
        Fail ``request`` with a timeout after ``timeout`` seconds (no-op when
        ``timeout`` is ``None``). The timer is cancelled as soon as the request
        completes, so finished requests are not kept alive by pending timers.
        """
        if timeout is None:
            return
        handle = asyncio.get_running_loop().call_later(
            timeout,
            functools.partial(request.enforce_deadline, timeout),
        )
        request.future.add_done_callback(lambda _: handle.cancel())

    async def create_request(
        self,
        command: bytes,
        *args: RedisValueT,
        noreply: bool | None = None,
        decode: RedisValueT | None = None,
        encoding: str | None = None,
        raise_exceptions: bool = True,
        timeout: float | None = None,
    ) -> asyncio.Future[ResponseType]:
        """
        Send a command to the redis server

        :return: a future resolved with the response, or one already resolved
         with ``None`` when no reply is expected (``noreply``)
        """
        from coredis.commands.constants import CommandName

        if not self.is_connected:
            await self.connect()

        cmd_list: list[bytes] = []
        request_timeout: float | None = timeout or self._stream_timeout
        if self.is_connected and noreply and not self.noreply:
            # CLIENT REPLY SKIP suppresses the reply for just the next command
            cmd_list = self.packer.pack_command(CommandName.CLIENT_REPLY, PureToken.SKIP)
        cmd_list.extend(self.packer.pack_command(command, *args))
        await self._send_packed_command(cmd_list, timeout=request_timeout)

        self.last_active_at = time.time()

        if self.noreply_set or noreply:
            return self._completed_future()

        request = Request(
            weakref.proxy(self),
            command,
            bool(decode) if decode is not None else self.decode_responses,
            encoding or self.encoding,
            raise_exceptions,
        )
        self._requests.append(request)
        self._schedule_deadline(request, request_timeout)
        return request.future

    async def create_requests(
        self,
        commands: list[CommandInvocation],
        raise_exceptions: bool = True,
        timeout: float | None = None,
    ) -> list[asyncio.Future[ResponseType]]:
        """
        Send multiple commands to the redis server

        All commands are packed and written in a single write.

        :return: one future per command (in order). If replies are switched
         off on this connection (``CLIENT REPLY OFF``), each is already
         resolved with ``None``, matching :meth:`create_request`.
        """

        if not self.is_connected:
            await self.connect()

        request_timeout: float | None = timeout or self._stream_timeout

        await self._send_packed_command(
            self.packer.pack_commands([(cmd.command, *cmd.args) for cmd in commands]),
            timeout=request_timeout,
        )

        self.last_active_at = time.time()
        if self.noreply_set:
            # the server will not reply, so nothing may be queued for responses
            return [self._completed_future() for _ in commands]

        requests: list[asyncio.Future[ResponseType]] = []
        for cmd in commands:
            request = Request(
                weakref.proxy(self),
                cmd.command,
                bool(cmd.decode) if cmd.decode is not None else self.decode_responses,
                cmd.encoding or self.encoding,
                raise_exceptions,
            )
            self._requests.append(request)
            self._schedule_deadline(request, request_timeout)
            requests.append(request.future)
        return requests

    def disconnect(self) -> None:
        """
        Disconnect from the Redis server

        Closes the transport, cancels pending push message readers and fails
        every pending request with the last transport error (or a
        :exc:`~coredis.exceptions.ConnectionError`).
        """
        self.needs_handshake = True
        self.noreply_set = False
        self._parser.on_disconnect()
        if self._transport:
            with suppress(RuntimeError):
                self._transport.close()

        disconnect_exc = self._last_error or ConnectionError("connection lost")
        while self._read_waiters:
            waiter = self._read_waiters.pop()
            if not waiter.done():
                with suppress(RuntimeError):
                    waiter.cancel()
        while self._requests:
            request = self._requests.popleft()
            if not request.future.done():
                request.future.set_exception(disconnect_exc)
        self._transport = None
        # Wake any writer blocked in _send_packed_command (e.g. after
        # pause_writing) so it observes the lost transport instead of
        # waiting indefinitely. connection_made() sets it again on reconnect.
        with suppress(RuntimeError):
            self._write_ready.set()


class Connection(BaseConnection):
    """
    Connection to a redis server over TCP, wrapped in TLS when an
    ``ssl_context`` is supplied.
    """

    description: ClassVar[str] = "Connection<host={host},port={port},db={db}>"
    locator: ClassVar[str] = "host={host},port={port}"

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 6379,
        username: str | None = None,
        password: str | None = None,
        credential_provider: AbstractCredentialProvider | None = None,
        db: int | None = 0,
        stream_timeout: float | None = None,
        connect_timeout: float | None = None,
        ssl_context: ssl.SSLContext | None = None,
        encoding: str = "utf-8",
        decode_responses: bool = False,
        socket_keepalive: bool | None = None,
        socket_keepalive_options: dict[int, int | bytes] | None = None,
        *,
        client_name: str | None = None,
        protocol_version: Literal[2, 3] = 3,
        noreply: bool = False,
        noevict: bool = False,
        notouch: bool = False,
    ):
        super().__init__(
            stream_timeout,
            encoding,
            decode_responses,
            client_name=client_name,
            protocol_version=protocol_version,
            noreply=noreply,
            noevict=noevict,
            notouch=notouch,
        )
        self.host = host
        self.port = port
        self.username: str | None = username
        self.password: str | None = password
        self.credential_provider: AbstractCredentialProvider | None = credential_provider
        self.db: int | None = db
        self.ssl_context = ssl_context
        self._connect_timeout = connect_timeout
        self._description_args: Callable[..., dict[str, str | int | None]] = lambda: {
            "host": self.host,
            "port": self.port,
            "db": self.db,
        }
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options: dict[int, int | bytes] = socket_keepalive_options or {}

    def _configure_socket(self, transport: asyncio.BaseTransport) -> None:
        """
        Apply ``SO_KEEPALIVE`` and any ``socket_keepalive_options`` (set at
        the ``SOL_TCP`` level) to the socket underlying ``transport``.

        :raises OSError: / :exc:`TypeError` for invalid socket options
        """
        sock = transport.get_extra_info("socket")
        if sock is None:
            return
        # TCP_KEEPALIVE
        if self.socket_keepalive:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        for option, value in self.socket_keepalive_options.items():
            sock.setsockopt(socket.SOL_TCP, option, value)

    async def _connect(self) -> None:
        """
        Open the transport and run the post connect handshake.

        Concurrent callers are serialized on the transport lock, and this is
        a no-op if a transport already exists. If configuring the socket or
        the handshake fails, the transport is torn down so that a retried
        :meth:`connect` starts afresh instead of returning early with a
        half-initialized connection.

        :raises ConnectionError: if the connection is not established within
         ``connect_timeout`` seconds
        """
        async with self._transport_lock:
            if self._transport:
                return
            # ``ssl=None`` yields a plain TCP connection
            connection = asyncio.get_running_loop().create_connection(
                lambda: self, host=self.host, port=self.port, ssl=self.ssl_context
            )

            try:
                async with async_timeout.timeout(self._connect_timeout):
                    transport, _ = await connection
            except asyncio.TimeoutError:
                raise ConnectionError(
                    f"Unable to establish a connection within {self._connect_timeout} seconds"
                )
            try:
                # `socket_keepalive_options` might contain invalid options
                # causing an error
                self._configure_socket(transport)
                await self.on_connect()
            except BaseException:
                transport.close()
                self.disconnect()
                raise


class UnixDomainSocketConnection(BaseConnection):
    """
    Connection to a redis server over a unix domain socket located at
    ``path``.

    Unrecognized keyword arguments are accepted and ignored.
    """

    description: ClassVar[str] = "UnixDomainSocketConnection<path={path},db={db}>"
    locator: ClassVar[str] = "path={path}"

    def __init__(
        self,
        path: str = "",
        username: str | None = None,
        password: str | None = None,
        credential_provider: AbstractCredentialProvider | None = None,
        db: int = 0,
        stream_timeout: float | None = None,
        connect_timeout: float | None = None,
        encoding: str = "utf-8",
        decode_responses: bool = False,
        *,
        client_name: str | None = None,
        protocol_version: Literal[2, 3] = 3,
        noreply: bool = False,
        noevict: bool = False,
        notouch: bool = False,
        **_: RedisValueT,
    ) -> None:
        super().__init__(
            stream_timeout,
            encoding,
            decode_responses,
            client_name=client_name,
            protocol_version=protocol_version,
            noreply=noreply,
            noevict=noevict,
            notouch=notouch,
        )
        self.path = path
        self.db = db
        self.username = username
        self.password = password
        self.credential_provider = credential_provider
        self._connect_timeout = connect_timeout
        self._description_args = lambda: {"path": self.path, "db": self.db}

    async def _connect(self) -> None:
        """
        Open the unix socket transport and run the post connect handshake.

        Concurrent callers are serialized on the transport lock, and this is
        a no-op if a transport already exists. If the handshake fails, the
        transport is torn down, so a retry neither reuses a half-initialized
        connection nor leaks the previous transport.

        :raises ConnectionError: if the connection is not established within
         ``connect_timeout`` seconds
        """
        async with self._transport_lock:
            if self._transport:
                return
            try:
                async with async_timeout.timeout(self._connect_timeout):
                    await asyncio.get_running_loop().create_unix_connection(
                        lambda: self, path=self.path
                    )
            except asyncio.TimeoutError:
                raise ConnectionError(
                    f"Unable to establish a connection within {self._connect_timeout} seconds"
                )
            try:
                await self.on_connect()
            except BaseException:
                self.disconnect()
                raise


class ClusterConnection(Connection):
    "Manages TCP communication to and from a Redis server"

    description: ClassVar[str] = "ClusterConnection<host={host},port={port}>"
    locator: ClassVar[str] = "host={host},port={port}"
    node: ManagedNode

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 6379,
        username: str | None = None,
        password: str | None = None,
        credential_provider: AbstractCredentialProvider | None = None,
        db: int | None = 0,
        stream_timeout: float | None = None,
        connect_timeout: float | None = None,
        ssl_context: ssl.SSLContext | None = None,
        encoding: str = "utf-8",
        decode_responses: bool = False,
        socket_keepalive: bool | None = None,
        socket_keepalive_options: dict[int, int | bytes] | None = None,
        *,
        client_name: str | None = None,
        protocol_version: Literal[2, 3] = 3,
        read_from_replicas: bool = False,
        noreply: bool = False,
        noevict: bool = False,
        notouch: bool = False,
    ) -> None:
        self.read_from_replicas = read_from_replicas
        super().__init__(
            host=host,
            port=port,
            username=username,
            password=password,
            credential_provider=credential_provider,
            db=db,
            stream_timeout=stream_timeout,
            connect_timeout=connect_timeout,
            ssl_context=ssl_context,
            encoding=encoding,
            decode_responses=decode_responses,
            socket_keepalive=socket_keepalive,
            socket_keepalive_options=socket_keepalive_options,
            client_name=client_name,
            protocol_version=protocol_version,
            noreply=noreply,
            noevict=noevict,
            notouch=notouch,
        )

    async def on_connect(self) -> None:
        """
        Initialize the connection, authenticate and select a database and send
        `READONLY` if `read_from_replicas` is set during initialization.

        :raises ConnectionError: if ``READONLY`` is not acknowledged with ``OK``
         (an explicit check, unlike ``assert``, is not stripped under ``-O``)

        :meta private:
        """

        await super().on_connect()
        if self.read_from_replicas:
            if await (await self.create_request(b"READONLY", decode=False)) != b"OK":
                raise ConnectionError("Unable to enable READONLY mode for reading from replicas")